diff --git a/.github/workflows/build-and-run.yml b/.github/workflows/build-and-run.yml
new file mode 100644
index 00000000..cd84a896
--- /dev/null
+++ b/.github/workflows/build-and-run.yml
@@ -0,0 +1,163 @@
+name: GPULlama3 Build & Run
+
+on:
+  push:
+    branches: [ main ]
+  pull_request:
+    branches: [ main ]
+    types: [opened, synchronize, reopened]
+
+env:
+  JAVA_HOME: /opt/jenkins/jdks/graal-23.1.0/jdk-21.0.3
+  TORNADO_ROOT: ${{ github.workspace }}/GPULlama3.java/external/tornadovm
+  LLAMA_ROOT: ${{ github.workspace }}
+  GRAAL_JARS: /opt/graalJars
+  MODELS_DIR: /opt/models
+
+jobs:
+  code-quality:
+    runs-on: self-hosted
+    timeout-minutes: 30
+
+    steps:
+      - name: Checkout GPULlama3
+        uses: actions/checkout@v4
+
+      - name: Check code formatting (Spotless)
+        run: |
+          cd ${{ github.workspace }}
+          ./mvnw -T12C -Pspotless spotless:check
+
+  build-and-run:
+    runs-on: [self-hosted]
+    needs: code-quality
+    timeout-minutes: 30
+
+    strategy:
+      fail-fast: true
+      matrix:
+        backend:
+          - name: opencl
+          - name: ptx
+
+    steps:
+      - name: Checkout GPULlama3
+        uses: actions/checkout@v4
+
+      - name: Clone TornadoVM master
+        run: |
+          git clone --depth 1 --branch master \
+            https://github.com/beehive-lab/TornadoVM.git \
+            $TORNADO_ROOT
+      - name: Set up Python venv for TornadoVM
+        run: |
+          python3 -m venv $TORNADO_ROOT/venv
+          source $TORNADO_ROOT/venv/bin/activate
+          python --version
+      - name: Build TornadoVM
+        run: |
+          cd $TORNADO_ROOT
+          mkdir -p graalJars && cp $GRAAL_JARS/* graalJars/
+          source venv/bin/activate
+          echo "=== Building TornadoVM ==="
+
+          make BACKEND=${{ matrix.backend.name }}
+
+          echo "=== Searching for TornadoVM SDK directory ==="
+          SDK_DIR=$(find dist -type d -maxdepth 3 -path "*/tornadovm-*-${{ matrix.backend.name }}" | head -n 1)
+          if [ -z "$SDK_DIR" ]; then
+            echo "::error::Could not locate TornadoVM SDK directory!"
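+            # Show the dist/ layout in the job log so a missing SDK is easy to diagnose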
+            find dist -maxdepth 5 -type d
+            exit 1
+          fi
+          FULL_SDK="${PWD}/${SDK_DIR}"
+          echo "Detected TornadoVM SDK: $FULL_SDK"
+
+          # Export for current shell session
+          export TORNADO_SDK="$FULL_SDK"
+          export PATH="$FULL_SDK/bin:$JAVA_HOME/bin:$PATH"
+
+          # Save for subsequent steps
+          echo "TORNADO_SDK=$FULL_SDK" >> $GITHUB_ENV
+          echo "PATH=$PATH" >> $GITHUB_ENV
+
+          echo "=== Checking tornado CLI ==="
+          which tornado || { echo "::error::tornado not in PATH"; exit 1; }
+          tornado --devices
+      - name: Build GPULlama3.java
+        run: |
+          cd ${{ github.workspace }}
+          echo "Using TORNADO_SDK=$TORNADO_SDK"
+          export PATH="$TORNADO_SDK/bin:$JAVA_HOME/bin:$PATH"
+          tornado --version
+          ./mvnw clean package -DskipTests
+      - name: FP16 - Run Llama-3.2-1B-Instruct-F16.gguf
+        run: |
+          cd ${{ github.workspace }}
+          export PATH="$TORNADO_SDK/bin:$JAVA_HOME/bin:$PATH"
+          ./llama-tornado --gpu --${{ matrix.backend.name }} \
+            --model $MODELS_DIR/Llama-3.2-1B-Instruct-F16.gguf \
+            --prompt "Say hello"
+      - name: FP16 - Run Qwen3-4B-f16.gguf
+        run: |
+          cd ${{ github.workspace }}
+          export PATH="$TORNADO_SDK/bin:$JAVA_HOME/bin:$PATH"
+          ./llama-tornado --gpu --${{ matrix.backend.name }} \
+            --model $MODELS_DIR/Qwen3-4B-f16.gguf \
+            --prompt "Say hello"
+      - name: FP16 - Run Mistral-7B-Instruct-v0.3.fp16.gguf
+        run: |
+          cd ${{ github.workspace }}
+          export PATH="$TORNADO_SDK/bin:$JAVA_HOME/bin:$PATH"
+          ./llama-tornado --gpu --${{ matrix.backend.name }} \
+            --model $MODELS_DIR/Mistral-7B-Instruct-v0.3.fp16.gguf \
+            --prompt "Say hello"
+      - name: FP16 - Run Qwen2.5-1.5b-instruct-fp16.gguf
+        run: |
+          cd ${{ github.workspace }}
+          export PATH="$TORNADO_SDK/bin:$JAVA_HOME/bin:$PATH"
+          ./llama-tornado --gpu --${{ matrix.backend.name }} \
+            --model $MODELS_DIR/qwen2.5-1.5b-instruct-fp16.gguf \
+            --prompt "Say hello"
+      - name: FP16 - Run Phi-3-mini-4k-instruct-fp16.gguf
+        run: |
+          cd ${{ github.workspace }}
+          export PATH="$TORNADO_SDK/bin:$JAVA_HOME/bin:$PATH"
+          ./llama-tornado --gpu --${{ matrix.backend.name }} \
+            --model $MODELS_DIR/Phi-3-mini-4k-instruct-fp16.gguf \
+            --prompt "Say hello"
+      - name: Q8 - Run Llama-3.2-1B-Instruct-Q8_0.gguf
+        run: |
+          cd ${{ github.workspace }}
+          export PATH="$TORNADO_SDK/bin:$JAVA_HOME/bin:$PATH"
+          ./llama-tornado --gpu --${{ matrix.backend.name }} \
+            --model $MODELS_DIR/Llama-3.2-1B-Instruct-Q8_0.gguf \
+            --prompt "Say hello"
+      - name: Q8 - Run Qwen3-0.6B-Q8_0.gguf
+        run: |
+          cd ${{ github.workspace }}
+          export PATH="$TORNADO_SDK/bin:$JAVA_HOME/bin:$PATH"
+          ./llama-tornado --gpu --${{ matrix.backend.name }} \
+            --model $MODELS_DIR/Qwen3-0.6B-Q8_0.gguf \
+            --prompt "Say hello"
+      - name: Q8 - Run Phi-3-mini-4k-instruct-Q8_0.gguf
+        run: |
+          cd ${{ github.workspace }}
+          export PATH="$TORNADO_SDK/bin:$JAVA_HOME/bin:$PATH"
+          ./llama-tornado --gpu --${{ matrix.backend.name }} \
+            --model $MODELS_DIR/Phi-3-mini-4k-instruct-Q8_0.gguf \
+            --prompt "Say hello"
+      - name: Q8 - Run Qwen2.5-1.5b-instruct-q8_0.gguf
+        run: |
+          cd ${{ github.workspace }}
+          export PATH="$TORNADO_SDK/bin:$JAVA_HOME/bin:$PATH"
+          ./llama-tornado --gpu --${{ matrix.backend.name }} \
+            --model $MODELS_DIR/qwen2.5-1.5b-instruct-q8_0.gguf \
+            --prompt "Say hello"
+      - name: Q8 - Mistral-7B-Instruct-v0.3.Q8_0.gguf
+        run: |
+          cd ${{ github.workspace }}
+          export PATH="$TORNADO_SDK/bin:$JAVA_HOME/bin:$PATH"
+          ./llama-tornado --gpu --${{ matrix.backend.name }} \
+            --model $MODELS_DIR/Mistral-7B-Instruct-v0.3.Q8_0.gguf \
+            --prompt "Say hello"
diff --git a/.github/workflows/deploy-maven-central.yml b/.github/workflows/deploy-maven-central.yml
new file mode 100644
index 00000000..a723e83c
--- /dev/null
+++ b/.github/workflows/deploy-maven-central.yml
@@ -0,0 +1,88 @@
+name: Deploy to Maven Central
+
+on:
+  push:
+    tags:
+      - 'v*'
+      - '[0-9]+.[0-9]+.[0-9]+*'
+  workflow_dispatch:
+    inputs:
+      dry_run:
+        description: 'Dry run (skip actual deploy)'
+        required: false
+        default: false
+        type: boolean
+
+jobs:
+  deploy:
+    name: Deploy to Maven Central
+    runs-on: [self-hosted, Linux, x64]
+    timeout-minutes: 15
+    env:
+      JAVA_HOME: /opt/jenkins/jdks/graal-23.1.0/jdk-21.0.3
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Setup environment
+        run: |
+          echo "$JAVA_HOME/bin" >> $GITHUB_PATH
+
+      - name: Configure Maven settings
+        run: |
+          mkdir -p ~/.m2
+          cat > ~/.m2/settings.xml << 'EOF'
+          <settings>
+            <servers>
+              <server>
+                <id>central</id>
+                <username>${env.MAVEN_USERNAME}</username>
+                <password>${env.MAVEN_PASSWORD}</password>
+              </server>
+            </servers>
+            <profiles>
+              <profile>
+                <id>gpg</id>
+                <properties>
+                  <gpg.executable>gpg</gpg.executable>
+                  <gpg.keyname>${env.GPG_KEYNAME}</gpg.keyname>
+                  <gpg.passphrase>${env.GPG_PASSPHRASE}</gpg.passphrase>
+                </properties>
+              </profile>
+            </profiles>
+            <activeProfiles>
+              <activeProfile>gpg</activeProfile>
+            </activeProfiles>
+          </settings>
+          EOF
+
+      - name: Import GPG key
+        run: |
+          export GPG_TTY=$(tty)
+          echo "${{ secrets.GPG_PRIVATE_KEY }}" | gpg --batch --import
+
+      - name: Deploy to Maven Central
+        if: ${{ !inputs.dry_run }}
+        run: |
+          ./mvnw clean deploy \
+            -P release \
+            -DskipTests \
+            --batch-mode
+        env:
+          MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }}
+          MAVEN_PASSWORD: ${{ secrets.OSSRH_TOKEN }}
+          GPG_KEYNAME: ${{ secrets.GPG_KEYNAME }}
+          GPG_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }}
+
+      - name: Dry Run - Verify build only
+        if: ${{ inputs.dry_run }}
+        run: |
+          ./mvnw clean verify \
+            -P release \
+            -DskipTests \
+            --batch-mode
+        env:
+          GPG_KEYNAME: ${{ secrets.GPG_KEYNAME }}
+          GPG_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }}
\ No newline at end of file
diff --git a/.github/workflows/rerun-workflow.yml b/.github/workflows/rerun-workflow.yml
new file mode 100644
index 00000000..6891ad92
--- /dev/null
+++ b/.github/workflows/rerun-workflow.yml
@@ -0,0 +1,181 @@
+name: Rerun Workflows
+
+on:
+  issue_comment:
+    types: [created]
+
+jobs:
+  rerun:
+    name: Rerun CI Workflows
+    # Only run on PR comments (not issue comments) with /rerun command
+    if: |
+      github.event.issue.pull_request &&
+      contains(github.event.comment.body, '/rerun')
+    runs-on: ubuntu-latest
+    permissions:
+      actions: write
+      pull-requests: write
+      contents: read
+
+    steps:
+      - name: Check for help command
+        id: help
+        uses: actions/github-script@v7
+        with:
+          script: |
+            const comment = context.payload.comment.body;
+            if (comment.match(/\/rerun\s+help/i)) {
+              await github.rest.issues.createComment({
+                owner: context.repo.owner,
+                repo: context.repo.repo,
+                issue_number: context.issue.number,
+                body: `## 🔄 Rerun Workflow Commands
+
+            | Command | Description |
+            |---------|-------------|
+            | \`/rerun\` | Rerun only **failed/cancelled/timed-out** workflows |
+            | \`/rerun all\` | Rerun **all** workflows for this PR |
+            | \`/rerun failed\` | Same as \`/rerun\` |
+            | \`/rerun <name>\` | Rerun workflows matching \`<name>\` (e.g. \`/rerun ci\`, \`/rerun build\`) |
+            | \`/rerun help\` | Show this help message |
+
+            **Note:** Only completed workflows can be rerun.
In-progress workflows are skipped.` + }); + core.setOutput('is_help', 'true'); + } else { + core.setOutput('is_help', 'false'); + } + + - name: Get PR SHA + if: steps.help.outputs.is_help != 'true' + id: pr + uses: actions/github-script@v7 + with: + script: | + const { data: pr } = await github.rest.pulls.get({ + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: context.issue.number + }); + core.setOutput('sha', pr.head.sha); + core.setOutput('head_ref', pr.head.ref); + console.log(`PR #${context.issue.number} SHA: ${pr.head.sha}`); + console.log(`PR head ref: ${pr.head.ref}`); + + - name: Add reaction to comment + if: steps.help.outputs.is_help != 'true' + uses: actions/github-script@v7 + with: + script: | + await github.rest.reactions.createForIssueComment({ + owner: context.repo.owner, + repo: context.repo.repo, + comment_id: context.payload.comment.id, + content: 'rocket' + }); + + - name: Post start comment + if: steps.help.outputs.is_help != 'true' + uses: actions/github-script@v7 + with: + script: | + const comment = context.payload.comment.body; + const rerunMatch = comment.match(/\/rerun\s*(\S+)?/); + const rerunArg = rerunMatch && rerunMatch[1] ? rerunMatch[1] : 'failed'; + + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + body: `🚀 **Workflow rerun started**\n\nMode: \`${rerunArg}\`\nTriggered by: @${context.payload.comment.user.login}\n\n[View Actions](https://github.com/${context.repo.owner}/${context.repo.repo}/actions)` + }); + + - name: Rerun failed workflows + if: steps.help.outputs.is_help != 'true' + uses: actions/github-script@v7 + with: + script: | + const sha = '${{ steps.pr.outputs.sha }}'; + const headRef = '${{ steps.pr.outputs.head_ref }}'; + + // Get all workflow runs for this PR's head SHA + const { data: runs } = await github.rest.actions.listWorkflowRunsForRepo({ + owner: context.repo.owner, + repo: context.repo.repo, + head_sha: sha, + per_page: 100 + }); + + console.log(`Found ${runs.total_count} workflow runs for SHA ${sha}`); + + if (runs.total_count === 0) { + console.log('No workflow runs found for this PR'); + return; + } + + // Parse command for specific workflow filter + // Supports: /rerun, /rerun all, /rerun failed, /rerun + const comment = context.payload.comment.body; + const rerunMatch = comment.match(/\/rerun\s*(\S+)?/); + const rerunArg = rerunMatch && rerunMatch[1] ? 
rerunMatch[1].toLowerCase() : 'failed'; + + console.log(`Rerun mode: ${rerunArg}`); + + let rerunCount = 0; + + for (const run of runs.workflow_runs) { + const shouldRerun = + rerunArg === 'all' || + (rerunArg === 'failed' && ['failure', 'cancelled', 'timed_out'].includes(run.conclusion)) || + run.name.toLowerCase().includes(rerunArg); + + if (!shouldRerun) { + console.log(`Skipping ${run.name} (status: ${run.status}, conclusion: ${run.conclusion})`); + continue; + } + + // Only rerun completed workflows + if (run.status !== 'completed') { + console.log(`Skipping ${run.name} - still ${run.status}`); + continue; + } + + try { + console.log(`Rerunning workflow: ${run.name} (ID: ${run.id})`); + + // Use rerun-failed-jobs if available and workflow failed, otherwise full rerun + if (['failure', 'cancelled', 'timed_out'].includes(run.conclusion)) { + await github.rest.actions.reRunWorkflowFailedJobs({ + owner: context.repo.owner, + repo: context.repo.repo, + run_id: run.id + }); + } else { + await github.rest.actions.reRunWorkflow({ + owner: context.repo.owner, + repo: context.repo.repo, + run_id: run.id + }); + } + rerunCount++; + } catch (error) { + console.log(`Failed to rerun ${run.name}: ${error.message}`); + } + } + + console.log(`Reran ${rerunCount} workflow(s)`); + + - name: Post completion comment + if: always() && steps.help.outputs.is_help != 'true' + uses: actions/github-script@v7 + with: + script: | + const status = '${{ job.status }}'; + const emoji = status === 'success' ? '✅' : '❌'; + + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + body: `${emoji} **Workflow rerun ${status}**\n\n[View Actions](https://github.com/${context.repo.owner}/${context.repo.repo}/actions)` + }); \ No newline at end of file diff --git a/.mvn/wrapper/maven-wrapper.properties b/.mvn/wrapper/maven-wrapper.properties new file mode 100644 index 00000000..7c6b218b --- /dev/null +++ b/.mvn/wrapper/maven-wrapper.properties @@ -0,0 +1,3 @@ +wrapperVersion=3.3.4 +distributionType=only-script +distributionUrl=https://repo.maven.apache.org/maven2/org/apache/maven/apache-maven/3.9.6/apache-maven-3.9.6-bin.zip diff --git a/Makefile b/Makefile index 752659b8..dd0eac84 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,9 @@ # Simple Makefile for Maven build without tests .PHONY: build clean package help +# Maven wrapper +MVN = ./mvnw + # Default target all: package @@ -9,16 +12,24 @@ build: clean package # Clean the project clean: - mvn clean + $(MVN) clean # Package the project without running tests package: - mvn package -DskipTests + $(MVN) package -DskipTests # Combined clean and package package-with-clean: - mvn clean package -DskipTests + $(MVN) clean package -DskipTests + +lint: + $(MVN) -T12C -Pspotless spotless:check + +# Automatically format the code to conform to a style guide. +# Modifies the code to ensure consistent formatting. 
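+# Run `make lint` first to see what would change without modifying any files.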
+format: + $(MVN) -T12C -Pspotless spotless:apply # Display help help: diff --git a/README.md b/README.md index 6b8e8167..04dbdbe4 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# GPULlama3.java powered by TornadoVM +# GPULlama3.java powered by TornadoVM [![GPULlama3 Build & Run Inference](https://github.com/beehive-lab/GPULlama3.java/actions/workflows/build-and-run.yml/badge.svg)](https://github.com/beehive-lab/GPULlama3.java/actions/workflows/build-and-run.yml) ![Java Version](https://img.shields.io/badge/java-21+-blue?style=for-the-badge&logo=openjdk) ![OpenCL](https://img.shields.io/badge/OpenCL-supported-blue?style=for-the-badge&logo=khronos) ![CUDA](https://img.shields.io/badge/CUDA/PTX-supported-76B900?style=for-the-badge&logo=nvidia) diff --git a/docs/GPULlama3_ROADMAP.md b/docs/GPULlama3_ROADMAP.md index 44346210..4c18a9d5 100644 --- a/docs/GPULlama3_ROADMAP.md +++ b/docs/GPULlama3_ROADMAP.md @@ -2,9 +2,9 @@ - [Pending Merge] **LangChain4j integration** - [ ] **Additional quantization formats** - - [ ] Q8 + - [x] Q8 - [ ] Q4 - - [ ] INT8 native support for GPUs + - [x] INT8 native support for GPUs - [ ] **Additional architectures and model format** - [x] Mistral/Mixtral models - [x] Qwen @@ -20,5 +20,4 @@ - [ ] **Performance optimizations** - [ ] Multi-GPU support - [X] Memory-efficient attention mechanisms - - [ ] More Kernel fusion improvements -- [ ] **GraalVM Native Image** + - [x] More Kernel fusion improvements diff --git a/external/tornadovm b/external/tornadovm index 4a8b990b..f6de88c1 160000 --- a/external/tornadovm +++ b/external/tornadovm @@ -1 +1 @@ -Subproject commit 4a8b990b6d0196339a294f155ea6c52421a7cbe4 +Subproject commit f6de88c150117d17ddc04a749e34f7f4ac4d0429 diff --git a/llama-tornado b/llama-tornado index b59473f2..9c0d6ba8 100755 --- a/llama-tornado +++ b/llama-tornado @@ -410,7 +410,7 @@ def create_parser() -> argparse.ArgumentParser: const=Backend.PTX, help="Use PTX/CUDA backend", ) - hw_group.add_argument("--gpu-memory", default="7GB", help="GPU memory allocation") + hw_group.add_argument("--gpu-memory", default="14GB", help="GPU memory allocation") hw_group.add_argument("--heap-min", default="20g", help="Minimum JVM heap size") hw_group.add_argument("--heap-max", default="20g", help="Maximum JVM heap size") diff --git a/mvnw b/mvnw new file mode 100755 index 00000000..bd8896bf --- /dev/null +++ b/mvnw @@ -0,0 +1,295 @@ +#!/bin/sh +# ---------------------------------------------------------------------------- +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# ---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- +# Apache Maven Wrapper startup batch script, version 3.3.4 +# +# Optional ENV vars +# ----------------- +# JAVA_HOME - location of a JDK home dir, required when download maven via java source +# MVNW_REPOURL - repo url base for downloading maven distribution +# MVNW_USERNAME/MVNW_PASSWORD - user and password for downloading maven +# MVNW_VERBOSE - true: enable verbose log; debug: trace the mvnw script; others: silence the output +# ---------------------------------------------------------------------------- + +set -euf +[ "${MVNW_VERBOSE-}" != debug ] || set -x + +# OS specific support. +native_path() { printf %s\\n "$1"; } +case "$(uname)" in +CYGWIN* | MINGW*) + [ -z "${JAVA_HOME-}" ] || JAVA_HOME="$(cygpath --unix "$JAVA_HOME")" + native_path() { cygpath --path --windows "$1"; } + ;; +esac + +# set JAVACMD and JAVACCMD +set_java_home() { + # For Cygwin and MinGW, ensure paths are in Unix format before anything is touched + if [ -n "${JAVA_HOME-}" ]; then + if [ -x "$JAVA_HOME/jre/sh/java" ]; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + JAVACCMD="$JAVA_HOME/jre/sh/javac" + else + JAVACMD="$JAVA_HOME/bin/java" + JAVACCMD="$JAVA_HOME/bin/javac" + + if [ ! -x "$JAVACMD" ] || [ ! -x "$JAVACCMD" ]; then + echo "The JAVA_HOME environment variable is not defined correctly, so mvnw cannot run." >&2 + echo "JAVA_HOME is set to \"$JAVA_HOME\", but \"\$JAVA_HOME/bin/java\" or \"\$JAVA_HOME/bin/javac\" does not exist." >&2 + return 1 + fi + fi + else + JAVACMD="$( + 'set' +e + 'unset' -f command 2>/dev/null + 'command' -v java + )" || : + JAVACCMD="$( + 'set' +e + 'unset' -f command 2>/dev/null + 'command' -v javac + )" || : + + if [ ! -x "${JAVACMD-}" ] || [ ! -x "${JAVACCMD-}" ]; then + echo "The java/javac command does not exist in PATH nor is JAVA_HOME set, so mvnw cannot run." >&2 + return 1 + fi + fi +} + +# hash string like Java String::hashCode +hash_string() { + str="${1:-}" h=0 + while [ -n "$str" ]; do + char="${str%"${str#?}"}" + h=$(((h * 31 + $(LC_CTYPE=C printf %d "'$char")) % 4294967296)) + str="${str#?}" + done + printf %x\\n $h +} + +verbose() { :; } +[ "${MVNW_VERBOSE-}" != true ] || verbose() { printf %s\\n "${1-}"; } + +die() { + printf %s\\n "$1" >&2 + exit 1 +} + +trim() { + # MWRAPPER-139: + # Trims trailing and leading whitespace, carriage returns, tabs, and linefeeds. + # Needed for removing poorly interpreted newline sequences when running in more + # exotic environments such as mingw bash on Windows. 
+ printf "%s" "${1}" | tr -d '[:space:]' +} + +scriptDir="$(dirname "$0")" +scriptName="$(basename "$0")" + +# parse distributionUrl and optional distributionSha256Sum, requires .mvn/wrapper/maven-wrapper.properties +while IFS="=" read -r key value; do + case "${key-}" in + distributionUrl) distributionUrl=$(trim "${value-}") ;; + distributionSha256Sum) distributionSha256Sum=$(trim "${value-}") ;; + esac +done <"$scriptDir/.mvn/wrapper/maven-wrapper.properties" +[ -n "${distributionUrl-}" ] || die "cannot read distributionUrl property in $scriptDir/.mvn/wrapper/maven-wrapper.properties" + +case "${distributionUrl##*/}" in +maven-mvnd-*bin.*) + MVN_CMD=mvnd.sh _MVNW_REPO_PATTERN=/maven/mvnd/ + case "${PROCESSOR_ARCHITECTURE-}${PROCESSOR_ARCHITEW6432-}:$(uname -a)" in + *AMD64:CYGWIN* | *AMD64:MINGW*) distributionPlatform=windows-amd64 ;; + :Darwin*x86_64) distributionPlatform=darwin-amd64 ;; + :Darwin*arm64) distributionPlatform=darwin-aarch64 ;; + :Linux*x86_64*) distributionPlatform=linux-amd64 ;; + *) + echo "Cannot detect native platform for mvnd on $(uname)-$(uname -m), use pure java version" >&2 + distributionPlatform=linux-amd64 + ;; + esac + distributionUrl="${distributionUrl%-bin.*}-$distributionPlatform.zip" + ;; +maven-mvnd-*) MVN_CMD=mvnd.sh _MVNW_REPO_PATTERN=/maven/mvnd/ ;; +*) MVN_CMD="mvn${scriptName#mvnw}" _MVNW_REPO_PATTERN=/org/apache/maven/ ;; +esac + +# apply MVNW_REPOURL and calculate MAVEN_HOME +# maven home pattern: ~/.m2/wrapper/dists/{apache-maven-,maven-mvnd--}/ +[ -z "${MVNW_REPOURL-}" ] || distributionUrl="$MVNW_REPOURL$_MVNW_REPO_PATTERN${distributionUrl#*"$_MVNW_REPO_PATTERN"}" +distributionUrlName="${distributionUrl##*/}" +distributionUrlNameMain="${distributionUrlName%.*}" +distributionUrlNameMain="${distributionUrlNameMain%-bin}" +MAVEN_USER_HOME="${MAVEN_USER_HOME:-${HOME}/.m2}" +MAVEN_HOME="${MAVEN_USER_HOME}/wrapper/dists/${distributionUrlNameMain-}/$(hash_string "$distributionUrl")" + +exec_maven() { + unset MVNW_VERBOSE MVNW_USERNAME MVNW_PASSWORD MVNW_REPOURL || : + exec "$MAVEN_HOME/bin/$MVN_CMD" "$@" || die "cannot exec $MAVEN_HOME/bin/$MVN_CMD" +} + +if [ -d "$MAVEN_HOME" ]; then + verbose "found existing MAVEN_HOME at $MAVEN_HOME" + exec_maven "$@" +fi + +case "${distributionUrl-}" in +*?-bin.zip | *?maven-mvnd-?*-?*.zip) ;; +*) die "distributionUrl is not valid, must match *-bin.zip or maven-mvnd-*.zip, but found '${distributionUrl-}'" ;; +esac + +# prepare tmp dir +if TMP_DOWNLOAD_DIR="$(mktemp -d)" && [ -d "$TMP_DOWNLOAD_DIR" ]; then + clean() { rm -rf -- "$TMP_DOWNLOAD_DIR"; } + trap clean HUP INT TERM EXIT +else + die "cannot create temp dir" +fi + +mkdir -p -- "${MAVEN_HOME%/*}" + +# Download and Install Apache Maven +verbose "Couldn't find MAVEN_HOME, downloading and installing it ..." +verbose "Downloading from: $distributionUrl" +verbose "Downloading to: $TMP_DOWNLOAD_DIR/$distributionUrlName" + +# select .zip or .tar.gz +if ! 
command -v unzip >/dev/null; then + distributionUrl="${distributionUrl%.zip}.tar.gz" + distributionUrlName="${distributionUrl##*/}" +fi + +# verbose opt +__MVNW_QUIET_WGET=--quiet __MVNW_QUIET_CURL=--silent __MVNW_QUIET_UNZIP=-q __MVNW_QUIET_TAR='' +[ "${MVNW_VERBOSE-}" != true ] || __MVNW_QUIET_WGET='' __MVNW_QUIET_CURL='' __MVNW_QUIET_UNZIP='' __MVNW_QUIET_TAR=v + +# normalize http auth +case "${MVNW_PASSWORD:+has-password}" in +'') MVNW_USERNAME='' MVNW_PASSWORD='' ;; +has-password) [ -n "${MVNW_USERNAME-}" ] || MVNW_USERNAME='' MVNW_PASSWORD='' ;; +esac + +if [ -z "${MVNW_USERNAME-}" ] && command -v wget >/dev/null; then + verbose "Found wget ... using wget" + wget ${__MVNW_QUIET_WGET:+"$__MVNW_QUIET_WGET"} "$distributionUrl" -O "$TMP_DOWNLOAD_DIR/$distributionUrlName" || die "wget: Failed to fetch $distributionUrl" +elif [ -z "${MVNW_USERNAME-}" ] && command -v curl >/dev/null; then + verbose "Found curl ... using curl" + curl ${__MVNW_QUIET_CURL:+"$__MVNW_QUIET_CURL"} -f -L -o "$TMP_DOWNLOAD_DIR/$distributionUrlName" "$distributionUrl" || die "curl: Failed to fetch $distributionUrl" +elif set_java_home; then + verbose "Falling back to use Java to download" + javaSource="$TMP_DOWNLOAD_DIR/Downloader.java" + targetZip="$TMP_DOWNLOAD_DIR/$distributionUrlName" + cat >"$javaSource" <<-END + public class Downloader extends java.net.Authenticator + { + protected java.net.PasswordAuthentication getPasswordAuthentication() + { + return new java.net.PasswordAuthentication( System.getenv( "MVNW_USERNAME" ), System.getenv( "MVNW_PASSWORD" ).toCharArray() ); + } + public static void main( String[] args ) throws Exception + { + setDefault( new Downloader() ); + java.nio.file.Files.copy( java.net.URI.create( args[0] ).toURL().openStream(), java.nio.file.Paths.get( args[1] ).toAbsolutePath().normalize() ); + } + } + END + # For Cygwin/MinGW, switch paths to Windows format before running javac and java + verbose " - Compiling Downloader.java ..." + "$(native_path "$JAVACCMD")" "$(native_path "$javaSource")" || die "Failed to compile Downloader.java" + verbose " - Running Downloader.java ..." + "$(native_path "$JAVACMD")" -cp "$(native_path "$TMP_DOWNLOAD_DIR")" Downloader "$distributionUrl" "$(native_path "$targetZip")" +fi + +# If specified, validate the SHA-256 sum of the Maven distribution zip file +if [ -n "${distributionSha256Sum-}" ]; then + distributionSha256Result=false + if [ "$MVN_CMD" = mvnd.sh ]; then + echo "Checksum validation is not supported for maven-mvnd." >&2 + echo "Please disable validation by removing 'distributionSha256Sum' from your maven-wrapper.properties." >&2 + exit 1 + elif command -v sha256sum >/dev/null; then + if echo "$distributionSha256Sum $TMP_DOWNLOAD_DIR/$distributionUrlName" | sha256sum -c - >/dev/null 2>&1; then + distributionSha256Result=true + fi + elif command -v shasum >/dev/null; then + if echo "$distributionSha256Sum $TMP_DOWNLOAD_DIR/$distributionUrlName" | shasum -a 256 -c >/dev/null 2>&1; then + distributionSha256Result=true + fi + else + echo "Checksum validation was requested but neither 'sha256sum' or 'shasum' are available." >&2 + echo "Please install either command, or disable validation by removing 'distributionSha256Sum' from your maven-wrapper.properties." >&2 + exit 1 + fi + if [ $distributionSha256Result = false ]; then + echo "Error: Failed to validate Maven distribution SHA-256, your Maven distribution might be compromised." >&2 + echo "If you updated your Maven version, you need to update the specified distributionSha256Sum property." 
>&2 + exit 1 + fi +fi + +# unzip and move +if command -v unzip >/dev/null; then + unzip ${__MVNW_QUIET_UNZIP:+"$__MVNW_QUIET_UNZIP"} "$TMP_DOWNLOAD_DIR/$distributionUrlName" -d "$TMP_DOWNLOAD_DIR" || die "failed to unzip" +else + tar xzf${__MVNW_QUIET_TAR:+"$__MVNW_QUIET_TAR"} "$TMP_DOWNLOAD_DIR/$distributionUrlName" -C "$TMP_DOWNLOAD_DIR" || die "failed to untar" +fi + +# Find the actual extracted directory name (handles snapshots where filename != directory name) +actualDistributionDir="" + +# First try the expected directory name (for regular distributions) +if [ -d "$TMP_DOWNLOAD_DIR/$distributionUrlNameMain" ]; then + if [ -f "$TMP_DOWNLOAD_DIR/$distributionUrlNameMain/bin/$MVN_CMD" ]; then + actualDistributionDir="$distributionUrlNameMain" + fi +fi + +# If not found, search for any directory with the Maven executable (for snapshots) +if [ -z "$actualDistributionDir" ]; then + # enable globbing to iterate over items + set +f + for dir in "$TMP_DOWNLOAD_DIR"/*; do + if [ -d "$dir" ]; then + if [ -f "$dir/bin/$MVN_CMD" ]; then + actualDistributionDir="$(basename "$dir")" + break + fi + fi + done + set -f +fi + +if [ -z "$actualDistributionDir" ]; then + verbose "Contents of $TMP_DOWNLOAD_DIR:" + verbose "$(ls -la "$TMP_DOWNLOAD_DIR")" + die "Could not find Maven distribution directory in extracted archive" +fi + +verbose "Found extracted Maven distribution directory: $actualDistributionDir" +printf %s\\n "$distributionUrl" >"$TMP_DOWNLOAD_DIR/$actualDistributionDir/mvnw.url" +mv -- "$TMP_DOWNLOAD_DIR/$actualDistributionDir" "$MAVEN_HOME" || [ -d "$MAVEN_HOME" ] || die "fail to move MAVEN_HOME" + +clean || : +exec_maven "$@" diff --git a/mvnw.cmd b/mvnw.cmd new file mode 100644 index 00000000..5761d948 --- /dev/null +++ b/mvnw.cmd @@ -0,0 +1,189 @@ +<# : batch portion +@REM ---------------------------------------------------------------------------- +@REM Licensed to the Apache Software Foundation (ASF) under one +@REM or more contributor license agreements. See the NOTICE file +@REM distributed with this work for additional information +@REM regarding copyright ownership. The ASF licenses this file +@REM to you under the Apache License, Version 2.0 (the +@REM "License"); you may not use this file except in compliance +@REM with the License. You may obtain a copy of the License at +@REM +@REM http://www.apache.org/licenses/LICENSE-2.0 +@REM +@REM Unless required by applicable law or agreed to in writing, +@REM software distributed under the License is distributed on an +@REM "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +@REM KIND, either express or implied. See the License for the +@REM specific language governing permissions and limitations +@REM under the License. 
+@REM ---------------------------------------------------------------------------- + +@REM ---------------------------------------------------------------------------- +@REM Apache Maven Wrapper startup batch script, version 3.3.4 +@REM +@REM Optional ENV vars +@REM MVNW_REPOURL - repo url base for downloading maven distribution +@REM MVNW_USERNAME/MVNW_PASSWORD - user and password for downloading maven +@REM MVNW_VERBOSE - true: enable verbose log; others: silence the output +@REM ---------------------------------------------------------------------------- + +@IF "%__MVNW_ARG0_NAME__%"=="" (SET __MVNW_ARG0_NAME__=%~nx0) +@SET __MVNW_CMD__= +@SET __MVNW_ERROR__= +@SET __MVNW_PSMODULEP_SAVE=%PSModulePath% +@SET PSModulePath= +@FOR /F "usebackq tokens=1* delims==" %%A IN (`powershell -noprofile "& {$scriptDir='%~dp0'; $script='%__MVNW_ARG0_NAME__%'; icm -ScriptBlock ([Scriptblock]::Create((Get-Content -Raw '%~f0'))) -NoNewScope}"`) DO @( + IF "%%A"=="MVN_CMD" (set __MVNW_CMD__=%%B) ELSE IF "%%B"=="" (echo %%A) ELSE (echo %%A=%%B) +) +@SET PSModulePath=%__MVNW_PSMODULEP_SAVE% +@SET __MVNW_PSMODULEP_SAVE= +@SET __MVNW_ARG0_NAME__= +@SET MVNW_USERNAME= +@SET MVNW_PASSWORD= +@IF NOT "%__MVNW_CMD__%"=="" ("%__MVNW_CMD__%" %*) +@echo Cannot start maven from wrapper >&2 && exit /b 1 +@GOTO :EOF +: end batch / begin powershell #> + +$ErrorActionPreference = "Stop" +if ($env:MVNW_VERBOSE -eq "true") { + $VerbosePreference = "Continue" +} + +# calculate distributionUrl, requires .mvn/wrapper/maven-wrapper.properties +$distributionUrl = (Get-Content -Raw "$scriptDir/.mvn/wrapper/maven-wrapper.properties" | ConvertFrom-StringData).distributionUrl +if (!$distributionUrl) { + Write-Error "cannot read distributionUrl property in $scriptDir/.mvn/wrapper/maven-wrapper.properties" +} + +switch -wildcard -casesensitive ( $($distributionUrl -replace '^.*/','') ) { + "maven-mvnd-*" { + $USE_MVND = $true + $distributionUrl = $distributionUrl -replace '-bin\.[^.]*$',"-windows-amd64.zip" + $MVN_CMD = "mvnd.cmd" + break + } + default { + $USE_MVND = $false + $MVN_CMD = $script -replace '^mvnw','mvn' + break + } +} + +# apply MVNW_REPOURL and calculate MAVEN_HOME +# maven home pattern: ~/.m2/wrapper/dists/{apache-maven-,maven-mvnd--}/ +if ($env:MVNW_REPOURL) { + $MVNW_REPO_PATTERN = if ($USE_MVND -eq $False) { "/org/apache/maven/" } else { "/maven/mvnd/" } + $distributionUrl = "$env:MVNW_REPOURL$MVNW_REPO_PATTERN$($distributionUrl -replace "^.*$MVNW_REPO_PATTERN",'')" +} +$distributionUrlName = $distributionUrl -replace '^.*/','' +$distributionUrlNameMain = $distributionUrlName -replace '\.[^.]*$','' -replace '-bin$','' + +$MAVEN_M2_PATH = "$HOME/.m2" +if ($env:MAVEN_USER_HOME) { + $MAVEN_M2_PATH = "$env:MAVEN_USER_HOME" +} + +if (-not (Test-Path -Path $MAVEN_M2_PATH)) { + New-Item -Path $MAVEN_M2_PATH -ItemType Directory | Out-Null +} + +$MAVEN_WRAPPER_DISTS = $null +if ((Get-Item $MAVEN_M2_PATH).Target[0] -eq $null) { + $MAVEN_WRAPPER_DISTS = "$MAVEN_M2_PATH/wrapper/dists" +} else { + $MAVEN_WRAPPER_DISTS = (Get-Item $MAVEN_M2_PATH).Target[0] + "/wrapper/dists" +} + +$MAVEN_HOME_PARENT = "$MAVEN_WRAPPER_DISTS/$distributionUrlNameMain" +$MAVEN_HOME_NAME = ([System.Security.Cryptography.SHA256]::Create().ComputeHash([byte[]][char[]]$distributionUrl) | ForEach-Object {$_.ToString("x2")}) -join '' +$MAVEN_HOME = "$MAVEN_HOME_PARENT/$MAVEN_HOME_NAME" + +if (Test-Path -Path "$MAVEN_HOME" -PathType Container) { + Write-Verbose "found existing MAVEN_HOME at $MAVEN_HOME" + Write-Output "MVN_CMD=$MAVEN_HOME/bin/$MVN_CMD" + 
exit $? +} + +if (! $distributionUrlNameMain -or ($distributionUrlName -eq $distributionUrlNameMain)) { + Write-Error "distributionUrl is not valid, must end with *-bin.zip, but found $distributionUrl" +} + +# prepare tmp dir +$TMP_DOWNLOAD_DIR_HOLDER = New-TemporaryFile +$TMP_DOWNLOAD_DIR = New-Item -Itemtype Directory -Path "$TMP_DOWNLOAD_DIR_HOLDER.dir" +$TMP_DOWNLOAD_DIR_HOLDER.Delete() | Out-Null +trap { + if ($TMP_DOWNLOAD_DIR.Exists) { + try { Remove-Item $TMP_DOWNLOAD_DIR -Recurse -Force | Out-Null } + catch { Write-Warning "Cannot remove $TMP_DOWNLOAD_DIR" } + } +} + +New-Item -Itemtype Directory -Path "$MAVEN_HOME_PARENT" -Force | Out-Null + +# Download and Install Apache Maven +Write-Verbose "Couldn't find MAVEN_HOME, downloading and installing it ..." +Write-Verbose "Downloading from: $distributionUrl" +Write-Verbose "Downloading to: $TMP_DOWNLOAD_DIR/$distributionUrlName" + +$webclient = New-Object System.Net.WebClient +if ($env:MVNW_USERNAME -and $env:MVNW_PASSWORD) { + $webclient.Credentials = New-Object System.Net.NetworkCredential($env:MVNW_USERNAME, $env:MVNW_PASSWORD) +} +[Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12 +$webclient.DownloadFile($distributionUrl, "$TMP_DOWNLOAD_DIR/$distributionUrlName") | Out-Null + +# If specified, validate the SHA-256 sum of the Maven distribution zip file +$distributionSha256Sum = (Get-Content -Raw "$scriptDir/.mvn/wrapper/maven-wrapper.properties" | ConvertFrom-StringData).distributionSha256Sum +if ($distributionSha256Sum) { + if ($USE_MVND) { + Write-Error "Checksum validation is not supported for maven-mvnd. `nPlease disable validation by removing 'distributionSha256Sum' from your maven-wrapper.properties." + } + Import-Module $PSHOME\Modules\Microsoft.PowerShell.Utility -Function Get-FileHash + if ((Get-FileHash "$TMP_DOWNLOAD_DIR/$distributionUrlName" -Algorithm SHA256).Hash.ToLower() -ne $distributionSha256Sum) { + Write-Error "Error: Failed to validate Maven distribution SHA-256, your Maven distribution might be compromised. If you updated your Maven version, you need to update the specified distributionSha256Sum property." 
+  }
+}
+
+# unzip and move
+Expand-Archive "$TMP_DOWNLOAD_DIR/$distributionUrlName" -DestinationPath "$TMP_DOWNLOAD_DIR" | Out-Null
+
+# Find the actual extracted directory name (handles snapshots where filename != directory name)
+$actualDistributionDir = ""
+
+# First try the expected directory name (for regular distributions)
+$expectedPath = Join-Path "$TMP_DOWNLOAD_DIR" "$distributionUrlNameMain"
+$expectedMvnPath = Join-Path "$expectedPath" "bin/$MVN_CMD"
+if ((Test-Path -Path $expectedPath -PathType Container) -and (Test-Path -Path $expectedMvnPath -PathType Leaf)) {
+  $actualDistributionDir = $distributionUrlNameMain
+}
+
+# If not found, search for any directory with the Maven executable (for snapshots)
+if (!$actualDistributionDir) {
+  Get-ChildItem -Path "$TMP_DOWNLOAD_DIR" -Directory | ForEach-Object {
+    $testPath = Join-Path $_.FullName "bin/$MVN_CMD"
+    if (Test-Path -Path $testPath -PathType Leaf) {
+      $actualDistributionDir = $_.Name
+    }
+  }
+}
+
+if (!$actualDistributionDir) {
+  Write-Error "Could not find Maven distribution directory in extracted archive"
+}
+
+Write-Verbose "Found extracted Maven distribution directory: $actualDistributionDir"
+Rename-Item -Path "$TMP_DOWNLOAD_DIR/$actualDistributionDir" -NewName $MAVEN_HOME_NAME | Out-Null
+try {
+  Move-Item -Path "$TMP_DOWNLOAD_DIR/$MAVEN_HOME_NAME" -Destination $MAVEN_HOME_PARENT | Out-Null
+} catch {
+  if (! (Test-Path -Path "$MAVEN_HOME" -PathType Container)) {
+    Write-Error "fail to move MAVEN_HOME"
+  }
+} finally {
+  try { Remove-Item $TMP_DOWNLOAD_DIR -Recurse -Force | Out-Null }
+  catch { Write-Warning "Cannot remove $TMP_DOWNLOAD_DIR" }
+}
+
+Write-Output "MVN_CMD=$MAVEN_HOME/bin/$MVN_CMD"
diff --git a/pom.xml b/pom.xml
index 8c641c68..ed26f5c7 100644
--- a/pom.xml
+++ b/pom.xml
@@ -52,14 +52,14 @@
       <scope>test</scope>
     </dependency>
     <dependency>
-      <groupId>tornado</groupId>
+      <groupId>io.github.beehive-lab</groupId>
       <artifactId>tornado-api</artifactId>
-      <version>1.1.2-dev</version>
+      <version>2.0.0</version>
     </dependency>
     <dependency>
-      <groupId>tornado</groupId>
+      <groupId>io.github.beehive-lab</groupId>
       <artifactId>tornado-runtime</artifactId>
-      <version>1.1.2-dev</version>
+      <version>2.0.0</version>
     </dependency>
@@ -98,65 +98,167 @@
-      <plugin>
-        <groupId>org.apache.maven.plugins</groupId>
-        <artifactId>maven-source-plugin</artifactId>
-        <version>3.3.0</version>
-        <executions>
-          <execution>
-            <id>attach-sources</id>
-            <goals>
-              <goal>jar</goal>
-            </goals>
-          </execution>
-        </executions>
-      </plugin>
-      <plugin>
-        <groupId>org.apache.maven.plugins</groupId>
-        <artifactId>maven-javadoc-plugin</artifactId>
-        <version>3.6.3</version>
-        <executions>
-          <execution>
-            <id>attach-javadocs</id>
-            <goals>
-              <goal>jar</goal>
-            </goals>
-          </execution>
-        </executions>
-        <configuration>
-          <failOnError>false</failOnError>
-          <failOnWarnings>false</failOnWarnings>
-        </configuration>
-      </plugin>
-      <plugin>
-        <groupId>org.apache.maven.plugins</groupId>
-        <artifactId>maven-gpg-plugin</artifactId>
-        <version>3.2.4</version>
-        <executions>
-          <execution>
-            <id>sign-artifacts</id>
-            <phase>verify</phase>
-            <goals>
-              <goal>sign</goal>
-            </goals>
-          </execution>
-        </executions>
-      </plugin>
-      <plugin>
-        <groupId>org.sonatype.central</groupId>
-        <artifactId>central-publishing-maven-plugin</artifactId>
-        <version>0.8.0</version>
-        <extensions>true</extensions>
-        <configuration>
-          <publishingServerId>central</publishingServerId>
-          <autoPublish>true</autoPublish>
-        </configuration>
-      </plugin>
+  <profiles>
+    <profile>
+      <id>release</id>
+      <activation>
+        <activeByDefault>false</activeByDefault>
+      </activation>
+      <properties>
+        <skipTests>false</skipTests>
+      </properties>
+      <build>
+        <plugins>
+          <plugin>
+            <groupId>org.apache.maven.plugins</groupId>
+            <artifactId>maven-source-plugin</artifactId>
+            <version>3.3.1</version>
+            <executions>
+              <execution>
+                <id>attach-sources</id>
+                <goals>
+                  <goal>jar-no-fork</goal>
+                </goals>
+              </execution>
+            </executions>
+          </plugin>
+          <plugin>
+            <groupId>org.apache.maven.plugins</groupId>
+            <artifactId>maven-javadoc-plugin</artifactId>
+            <version>3.6.3</version>
+            <configuration>
+              <source>21</source>
+              <release>21</release>
+              <additionalJOptions>
+                <additionalJOption>--enable-preview</additionalJOption>
+                <additionalJOption>--add-modules=jdk.incubator.vector</additionalJOption>
+              </additionalJOptions>
+              <additionalOptions>
+                <additionalOption>--enable-preview</additionalOption>
+              </additionalOptions>
+              <failOnError>false</failOnError>
+              <failOnWarnings>false</failOnWarnings>
+              <doclint>none</doclint>
+            </configuration>
+            <executions>
+              <execution>
+                <id>attach-javadocs</id>
+                <phase>package</phase>
+                <goals>
+                  <goal>jar</goal>
+                </goals>
+              </execution>
+            </executions>
+          </plugin>
+          <plugin>
+            <groupId>org.apache.maven.plugins</groupId>
+            <artifactId>maven-gpg-plugin</artifactId>
+            <version>3.2.4</version>
+            <executions>
+              <execution>
+                <id>sign-artifacts</id>
+                <phase>verify</phase>
+                <goals>
+                  <goal>sign</goal>
+                </goals>
+              </execution>
+            </executions>
+          </plugin>
+          <plugin>
+            <groupId>org.sonatype.central</groupId>
+            <artifactId>central-publishing-maven-plugin</artifactId>
+            <version>0.8.0</version>
+            <extensions>true</extensions>
+            <configuration>
+              <publishingServerId>central</publishingServerId>
+              <autoPublish>true</autoPublish>
+            </configuration>
+          </plugin>
+        </plugins>
+      </build>
+    </profile>
+
+    <profile>
+      <id>spotless</id>
+      <build>
+        <plugins>
+          <plugin>
+            <groupId>com.diffplug.spotless</groupId>
+            <artifactId>spotless-maven-plugin</artifactId>
+            <version>2.44.4</version>
+            <configuration>
+              <ratchetFrom>origin/main</ratchetFrom>
+              <java>
+                <includes>
+                  <include>src/main/java/**/*.java</include>
+                  <include>src/test/java/**/*.java</include>
+                </includes>
+                <excludes>
+                  <exclude>**/target/**</exclude>
+                </excludes>
+                <googleJavaFormat>
+                  <version>1.19.2</version>
+                </googleJavaFormat>
+              </java>
+              <pom>
+                <includes>
+                  <include>pom.xml</include>
+                </includes>
+                <sortPom>
+                  <nrOfIndentSpace>4</nrOfIndentSpace>
+                  <expandEmptyElements>false</expandEmptyElements>
+                </sortPom>
+              </pom>
+              <markdown>
+                <includes>
+                  <include>**/*.md</include>
+                </includes>
+                <excludes>
+                  <exclude>**/target/**</exclude>
+                </excludes>
+              </markdown>
+              <!-- props -->
+              <formats>
+                <format>
+                  <includes>
+                    <include>src/**/*.properties</include>
+                  </includes>
+                  <excludes>
+                    <exclude>**/target/**</exclude>
+                  </excludes>
+                </format>
+              </formats>
+            </configuration>
+          </plugin>
+        </plugins>
+      </build>
+    </profile>
+  </profiles>
diff --git a/src/main/java/org/beehive/gpullama3/LlamaApp.java b/src/main/java/org/beehive/gpullama3/LlamaApp.java
index 7da9b878..822a082c 100644
--- a/src/main/java/org/beehive/gpullama3/LlamaApp.java
+++
b/src/main/java/org/beehive/gpullama3/LlamaApp.java @@ -1,10 +1,8 @@ package org.beehive.gpullama3; -import org.beehive.gpullama3.aot.AOT; import org.beehive.gpullama3.auxiliary.LastRunMetrics; import org.beehive.gpullama3.inference.sampler.Sampler; import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.model.loader.ModelLoader; import java.io.IOException; diff --git a/src/main/java/org/beehive/gpullama3/aot/AOT.java b/src/main/java/org/beehive/gpullama3/aot/AOT.java deleted file mode 100644 index 7fde18ca..00000000 --- a/src/main/java/org/beehive/gpullama3/aot/AOT.java +++ /dev/null @@ -1,85 +0,0 @@ -package org.beehive.gpullama3.aot; - -import org.beehive.gpullama3.auxiliary.Timer; -import org.beehive.gpullama3.core.model.GGUF; -import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry; -import org.beehive.gpullama3.model.loader.LlamaModelLoader; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.Options; -import org.beehive.gpullama3.model.format.LlamaChatFormat; -import org.beehive.gpullama3.model.llama.Llama; -import org.beehive.gpullama3.inference.weights.Weights; -import org.beehive.gpullama3.tokenizer.impl.LlamaTokenizer; - -import java.io.IOException; -import java.nio.channels.FileChannel; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.StandardOpenOption; -import java.util.Map; -import java.util.Objects; - -/** - * Support for AOT preloading of GGUF metadata with GraalVM's Native Image. - * - *
- * <p>
- * To preload a model at build time, pass {@code -Dllama.PreloadGGUF=/path/to/model.gguf}
- * to the native-image builder command. At runtime, the preloaded model will be used
- * iff the specified and preloaded file names (base name) match.
- */
-public final class AOT {
-    AOT.PartialModel preLoaded = AOT.PRELOADED_GGUF;
-
-    static LlamaModelLoader modelLoader;
-
-    record PartialModel(String modelFileName, Llama model, long tensorDataOffset, Map<String, GGUF.GGUFTensorInfo> tensorInfos) {
-    }
-
-    private static final PartialModel PRELOADED_GGUF = preLoadGGUF(System.getProperty("llama.PreloadGGUF"));
-
-    private static PartialModel preLoadGGUF(String modelPath) {
-        if (modelPath == null || modelPath.isEmpty()) {
-            return null;
-        }
-        try {
-            Path path = Path.of(modelPath);
-            if (!Files.exists(path) || !Files.isRegularFile(path)) {
-                throw new IllegalArgumentException("Cannot pre-load model: " + path);
-            }
-            GGUF gguf = GGUF.loadModel(path);
-            try (FileChannel fileChannel = FileChannel.open(path, StandardOpenOption.READ)) {
-                modelLoader = new LlamaModelLoader(fileChannel, gguf, Options.DEFAULT_MAX_TOKENS, false, false);
-                return new PartialModel(path.getFileName().toString(), modelLoader.loadModel(), // TODO: needs proper handling for AOT
-                        gguf.getTensorDataOffset(), gguf.getTensorInfos());
-            }
-        } catch (IOException e) {
-            throw new RuntimeException(e);
-        }
-    }
-
-    /**
-     * Tries to reuse a compatible AOT preloaded model.
-     * The file name (base name) must match with the preloaded file name.
-     * No checksum/hash is checked for performance reasons.
-     */
-    public static Model tryUsePreLoaded(Path modelPath, int contextLength) throws IOException {
-        AOT.PartialModel preLoaded = AOT.PRELOADED_GGUF;
-        if (preLoaded == null) {
-            return null; // no pre-loaded model stored
-        }
-        String optionsModel = modelPath.getFileName().toString();
-        String preLoadedModel = preLoaded.modelFileName();
-        if (!Objects.equals(optionsModel, preLoadedModel)) {
-            // Preloaded and specified model file names didn't match.
-            return null;
-        }
-        Llama baseModel = preLoaded.model();
-        try (var timer = Timer.log("Load tensors from pre-loaded model"); var fileChannel = FileChannel.open(modelPath, StandardOpenOption.READ)) {
-            // Load only the tensors (mmap slices).
-            Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, preLoaded.tensorDataOffset(), preLoaded.tensorInfos());
-            Weights weights = modelLoader.loadWeights(tensorEntries, baseModel.configuration());
-            return new Llama(baseModel.configuration().withContextLength(contextLength), baseModel.tokenizer(), weights, new LlamaChatFormat((LlamaTokenizer) baseModel.tokenizer()));
-        }
-    }
-}
-
diff --git a/src/main/java/org/beehive/gpullama3/core/types/Pair.java b/src/main/java/org/beehive/gpullama3/auxiliary/Pair.java
similarity index 61%
rename from src/main/java/org/beehive/gpullama3/core/types/Pair.java
rename to src/main/java/org/beehive/gpullama3/auxiliary/Pair.java
index 882d2f11..547280dd 100644
--- a/src/main/java/org/beehive/gpullama3/core/types/Pair.java
+++ b/src/main/java/org/beehive/gpullama3/auxiliary/Pair.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.core.types;
+package org.beehive.gpullama3.auxiliary;
 
 public record Pair<First, Second>(First first, Second second) {
 }
diff --git a/src/main/java/org/beehive/gpullama3/auxiliary/Tuple2.java b/src/main/java/org/beehive/gpullama3/auxiliary/Tuple2.java
index 140569bc..2a340162 100644
--- a/src/main/java/org/beehive/gpullama3/auxiliary/Tuple2.java
+++ b/src/main/java/org/beehive/gpullama3/auxiliary/Tuple2.java
@@ -19,20 +19,23 @@ public U getSecond() {
 
     @Override
     public String toString() {
-        return "Tuple2{" +
-                "first=" + first +
-                ", second=" + second +
-                '}';
+        return "Tuple2{" + "first=" + first + ", second=" + second + '}';
     }
 
     @Override
     public boolean equals(Object o) {
-        if (this == o) return true;
-        if (o == null || getClass() != o.getClass()) return false;
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
 
         Tuple2<?, ?> tuple2 = (Tuple2<?, ?>) o;
 
-        if (!first.equals(tuple2.first)) return false;
+        if (!first.equals(tuple2.first)) {
+            return false;
+        }
         return second.equals(tuple2.second);
     }
diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java
index eb21701c..8104e561 100644
--- a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java
+++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java
@@ -1,7 +1,7 @@
 package org.beehive.gpullama3.inference;
 
 import org.beehive.gpullama3.auxiliary.Parallel;
-import org.beehive.gpullama3.core.model.tensor.FloatTensor;
+import org.beehive.gpullama3.tensor.standard.FloatTensor;
 import org.beehive.gpullama3.inference.state.Phi3State;
 import org.beehive.gpullama3.inference.state.State;
 import org.beehive.gpullama3.inference.weights.standard.Phi3StandardWeights;
@@ -583,7 +583,7 @@ public static FloatArray forwardTornadoVM(Model model, State state, int token, i
         final Configuration configuration = model.configuration();
         final TornadoWeights weights = (TornadoWeights) model.weights();
 
-        MemorySegment.copy(weights.tokenEmbeddingTable.getSegment(), (long) token * configuration.dim() * Float.BYTES, state.wrapX.getSegment(), 0, configuration.dim() * Float.BYTES);
+        MemorySegment.copy(weights.getTokenEmbeddingTable().asFloatArray().getSegment(), (long) token * configuration.dim() * Float.BYTES, state.wrapX.getSegment(), 0, configuration.dim() * Float.BYTES);
 
         return tornadoVMMasterPlan.tornadoVMForwardExecuteLayered(position);
     }
diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngine.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngine.java
index e7b21cbb..7599f1b0 100644
---
a/src/main/java/org/beehive/gpullama3/inference/InferenceEngine.java
+++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngine.java
@@ -5,7 +5,7 @@
 import org.beehive.gpullama3.inference.state.State;
 import org.beehive.gpullama3.model.Configuration;
 import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.tokenizer.impl.Tokenizer;
+import org.beehive.gpullama3.tokenizer.Tokenizer;
 import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
diff --git a/src/main/java/org/beehive/gpullama3/inference/operation/RoPE.java b/src/main/java/org/beehive/gpullama3/inference/operation/RoPE.java
index a97a9519..5ed39eda 100644
--- a/src/main/java/org/beehive/gpullama3/inference/operation/RoPE.java
+++ b/src/main/java/org/beehive/gpullama3/inference/operation/RoPE.java
@@ -1,6 +1,6 @@
 package org.beehive.gpullama3.inference.operation;
 
-import org.beehive.gpullama3.core.types.Pair;
+import org.beehive.gpullama3.auxiliary.Pair;
 
 public final class RoPE {
     public static Pair<float[], float[]> precomputeFreqsCis(int contextLength, int headSize, double theta,
diff --git a/src/main/java/org/beehive/gpullama3/inference/sampler/CategoricalSampler.java b/src/main/java/org/beehive/gpullama3/inference/sampler/CategoricalSampler.java
index 94cd4467..b5da9d64 100644
--- a/src/main/java/org/beehive/gpullama3/inference/sampler/CategoricalSampler.java
+++ b/src/main/java/org/beehive/gpullama3/inference/sampler/CategoricalSampler.java
@@ -1,6 +1,6 @@
 package org.beehive.gpullama3.inference.sampler;
 
-import org.beehive.gpullama3.core.model.tensor.FloatTensor;
+import org.beehive.gpullama3.tensor.standard.FloatTensor;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
 
 import java.util.random.RandomGenerator;
diff --git a/src/main/java/org/beehive/gpullama3/inference/sampler/Sampler.java b/src/main/java/org/beehive/gpullama3/inference/sampler/Sampler.java
index f3a27c33..496d0761 100644
--- a/src/main/java/org/beehive/gpullama3/inference/sampler/Sampler.java
+++ b/src/main/java/org/beehive/gpullama3/inference/sampler/Sampler.java
@@ -1,21 +1,39 @@
 package org.beehive.gpullama3.inference.sampler;
 
 import org.beehive.gpullama3.Options;
-import org.beehive.gpullama3.core.model.tensor.FloatTensor;
+import org.beehive.gpullama3.tensor.standard.FloatTensor;
 import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.tornadovm.FloatArrayUtils;
+import org.beehive.gpullama3.tornadovm.utils.FloatArrayUtils;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
 
 import java.util.random.RandomGenerator;
 import java.util.random.RandomGeneratorFactory;
 
 /**
- * Generic interface for sampling tokens from probability distributions.
- * Supports both FloatTensor and FloatArray tensor implementations.
+ * Generic interface for sampling tokens from probability distributions. Supports both FloatTensor and FloatArray tensor implementations.
  */
 @FunctionalInterface
 public interface Sampler {
 
+    /**
+     * Argmax implementation for FloatTensor.
+     */
+    Sampler TENSOR_ARGMAX = tensor -> {
+        if (tensor instanceof FloatTensor) {
+            return ((FloatTensor) tensor).argmax();
+        } else if (tensor instanceof FloatArray) {
+            return argmaxFloatArray((FloatArray) tensor);
+        }
+        throw new IllegalArgumentException("Unsupported tensor type: " + (tensor != null ? tensor.getClass().getName() : "null"));
+    };
+    /**
+     * Legacy ARGMAX for backward compatibility.
+ * + * @deprecated Use TENSOR_ARGMAX instead + */ + @Deprecated + Sampler ARGMAX = TENSOR_ARGMAX; + /** * Creates and configures a sampler for token generation based on specified parameters. * @@ -103,42 +121,15 @@ static Sampler selectSampler(int vocabularySize, float temperature, float topp, return sampler; } - public static Sampler createSampler(Model model, Options options) { + static Sampler createSampler(Model model, Options options) { return selectSampler(model.configuration().vocabularySize(), options.temperature(), options.topp(), options.seed()); } - /** - * Sample a token from the provided tensor. - * - * @param tensor The tensor containing probabilities/logits - * @return The selected token index - */ - int sampleToken(Object tensor); - - /** - * Argmax implementation for FloatTensor. - */ - Sampler TENSOR_ARGMAX = tensor -> { - if (tensor instanceof FloatTensor) { - return ((FloatTensor) tensor).argmax(); - } else if (tensor instanceof FloatArray) { - return argmaxFloatArray((FloatArray) tensor); - } - throw new IllegalArgumentException("Unsupported tensor type: " + - (tensor != null ? tensor.getClass().getName() : "null")); - }; - - /** - * Legacy ARGMAX for backward compatibility. - * @deprecated Use TENSOR_ARGMAX instead - */ - @Deprecated - Sampler ARGMAX = TENSOR_ARGMAX; - /** * Find the index of the maximum value in a FloatArray. * - * @param array The FloatArray to find the maximum value in + * @param array + * The FloatArray to find the maximum value in * @return The index of the maximum value */ static int argmaxFloatArray(FloatArray array) { @@ -155,4 +146,13 @@ static int argmaxFloatArray(FloatArray array) { return maxIndex; } + + /** + * Sample a token from the provided tensor. + * + * @param tensor + * The tensor containing probabilities/logits + * @return The selected token index + */ + int sampleToken(Object tensor); } \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/inference/sampler/ToppSampler.java b/src/main/java/org/beehive/gpullama3/inference/sampler/ToppSampler.java index 2f52762d..fa8754d0 100644 --- a/src/main/java/org/beehive/gpullama3/inference/sampler/ToppSampler.java +++ b/src/main/java/org/beehive/gpullama3/inference/sampler/ToppSampler.java @@ -1,6 +1,6 @@ package org.beehive.gpullama3.inference.sampler; -import org.beehive.gpullama3.core.model.tensor.FloatTensor; +import org.beehive.gpullama3.tensor.standard.FloatTensor; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; import java.util.Comparator; diff --git a/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java b/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java index fe506451..9f9fdcdb 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java @@ -1,7 +1,7 @@ package org.beehive.gpullama3.inference.state; -import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor; -import org.beehive.gpullama3.core.model.tensor.FloatTensor; +import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; +import org.beehive.gpullama3.tensor.standard.FloatTensor; import org.beehive.gpullama3.model.Configuration; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; diff --git a/src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java b/src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java index 1d738259..d29ba130 100644 --- 
a/src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java @@ -1,7 +1,7 @@ package org.beehive.gpullama3.inference.state; -import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor; -import org.beehive.gpullama3.core.model.tensor.FloatTensor; +import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; +import org.beehive.gpullama3.tensor.standard.FloatTensor; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.phi3.Phi3Configuration; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; diff --git a/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java b/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java index d5623e88..da6d7046 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java @@ -1,7 +1,7 @@ package org.beehive.gpullama3.inference.state; -import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor; -import org.beehive.gpullama3.core.model.tensor.FloatTensor; +import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; +import org.beehive.gpullama3.tensor.standard.FloatTensor; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; diff --git a/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java b/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java index 16837270..d6a6d087 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java @@ -1,7 +1,7 @@ package org.beehive.gpullama3.inference.state; -import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor; -import org.beehive.gpullama3.core.model.tensor.FloatTensor; +import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; +import org.beehive.gpullama3.tensor.standard.FloatTensor; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; diff --git a/src/main/java/org/beehive/gpullama3/inference/state/State.java b/src/main/java/org/beehive/gpullama3/inference/state/State.java index 532e9863..01d94936 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/State.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/State.java @@ -1,6 +1,6 @@ package org.beehive.gpullama3.inference.state; -import org.beehive.gpullama3.core.model.tensor.FloatTensor; +import org.beehive.gpullama3.tensor.standard.FloatTensor; import org.beehive.gpullama3.model.Configuration; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/Weights.java b/src/main/java/org/beehive/gpullama3/inference/weights/Weights.java index 2672e606..1e753d59 100644 --- a/src/main/java/org/beehive/gpullama3/inference/weights/Weights.java +++ b/src/main/java/org/beehive/gpullama3/inference/weights/Weights.java @@ -1,6 +1,6 @@ package org.beehive.gpullama3.inference.weights; -import org.beehive.gpullama3.core.model.GGMLType; +import org.beehive.gpullama3.tensor.GGMLType; /** * The GPULlama3.java utilizes two distinct weight types: diff --git 
a/src/main/java/org/beehive/gpullama3/inference/weights/standard/LlamaStandardWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/standard/LlamaStandardWeights.java index 27ce301a..f5401a28 100644 --- a/src/main/java/org/beehive/gpullama3/inference/weights/standard/LlamaStandardWeights.java +++ b/src/main/java/org/beehive/gpullama3/inference/weights/standard/LlamaStandardWeights.java @@ -1,7 +1,7 @@ package org.beehive.gpullama3.inference.weights.standard; -import org.beehive.gpullama3.core.model.GGMLType; -import org.beehive.gpullama3.core.model.tensor.FloatTensor; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.standard.FloatTensor; /** * A model-specific implementation of {@link StandardWeights} for the Llama model. diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/standard/Phi3StandardWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/standard/Phi3StandardWeights.java index 5c331774..6e1c1c33 100644 --- a/src/main/java/org/beehive/gpullama3/inference/weights/standard/Phi3StandardWeights.java +++ b/src/main/java/org/beehive/gpullama3/inference/weights/standard/Phi3StandardWeights.java @@ -1,7 +1,7 @@ package org.beehive.gpullama3.inference.weights.standard; -import org.beehive.gpullama3.core.model.GGMLType; -import org.beehive.gpullama3.core.model.tensor.FloatTensor; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.standard.FloatTensor; public class Phi3StandardWeights extends StandardWeights { diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/standard/Qwen2StandardWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/standard/Qwen2StandardWeights.java index fe401d0e..663bc158 100644 --- a/src/main/java/org/beehive/gpullama3/inference/weights/standard/Qwen2StandardWeights.java +++ b/src/main/java/org/beehive/gpullama3/inference/weights/standard/Qwen2StandardWeights.java @@ -1,9 +1,8 @@ package org.beehive.gpullama3.inference.weights.standard; -import org.beehive.gpullama3.core.model.GGMLType; -import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor; -import org.beehive.gpullama3.core.model.tensor.FloatTensor; -import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; +import org.beehive.gpullama3.tensor.standard.FloatTensor; public class Qwen2StandardWeights extends StandardWeights { // Qwen2-specific weights diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/standard/Qwen3StandardWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/standard/Qwen3StandardWeights.java index 99a4634d..861a1ebf 100644 --- a/src/main/java/org/beehive/gpullama3/inference/weights/standard/Qwen3StandardWeights.java +++ b/src/main/java/org/beehive/gpullama3/inference/weights/standard/Qwen3StandardWeights.java @@ -1,7 +1,7 @@ package org.beehive.gpullama3.inference.weights.standard; -import org.beehive.gpullama3.core.model.GGMLType; -import org.beehive.gpullama3.core.model.tensor.FloatTensor; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.standard.FloatTensor; /** * A model-specific implementation of {@link StandardWeights} for the Qwen-3 model. 
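Editorial note: the package moves above (org.beehive.gpullama3.core.model.* to org.beehive.gpullama3.tensor.*) come with a second, larger change visible in the InferenceCore hunk and in the TornadoWeights hunks below: concrete FloatArray/HalfFloatArray fields are replaced by a TornadoTensor abstraction that call sites unwrap through accessors such as getTokenEmbeddingTable().asFloatArray(). A minimal sketch of what such a wrapper could look like, assuming only what the diff itself shows (asFloatArray() plus TornadoVM's array types); everything else here is illustrative, not the repository's actual API:

```java
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;

// Hypothetical sketch of a quantization-agnostic tensor wrapper in the spirit
// of the TornadoTensor type referenced by this diff; names other than
// asFloatArray() are assumptions made for illustration.
public interface TornadoTensorSketch {
    /** FP32 view; asFloatArray() is the accessor visible in the InferenceCore hunk above. */
    FloatArray asFloatArray();

    /** Assumed FP16 counterpart for half-precision weight tensors. */
    HalfFloatArray asHalfFloatArray();
}
```

Hiding the element type behind one interface is what lets the same constructor signatures carry FP16 and Q8_0 weights alike, which is why the hunks below switch every FloatArray/HalfFloatArray parameter to TornadoTensor.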
diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/standard/StandardWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/standard/StandardWeights.java index e6df9c6a..abae92f8 100644 --- a/src/main/java/org/beehive/gpullama3/inference/weights/standard/StandardWeights.java +++ b/src/main/java/org/beehive/gpullama3/inference/weights/standard/StandardWeights.java @@ -1,7 +1,7 @@ package org.beehive.gpullama3.inference.weights.standard; -import org.beehive.gpullama3.core.model.GGMLType; -import org.beehive.gpullama3.core.model.tensor.FloatTensor; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.standard.FloatTensor; import org.beehive.gpullama3.inference.weights.Weights; /** diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/LlamaTornadoWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/LlamaTornadoWeights.java index d8127007..98d8eb4c 100644 --- a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/LlamaTornadoWeights.java +++ b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/LlamaTornadoWeights.java @@ -1,49 +1,32 @@ package org.beehive.gpullama3.inference.weights.tornado; -import org.beehive.gpullama3.core.model.GGMLType; -import uk.ac.manchester.tornado.api.types.arrays.FloatArray; -import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.tornado.TornadoTensor; -/** - * A model-specific implementation of {@link TornadoWeights} for the Llama model. - * This class encapsulates the weights required for performing GPU-accelerated - * inference of the Llama model using TornadoVM. - * - *
- * <p>
- * Note: This weight format can also be used with the Mistral model.
- * </p>
- */ public class LlamaTornadoWeights extends TornadoWeights { - // @formatter:off public LlamaTornadoWeights( - FloatArray tokenEmbeddingTable, - FloatArray[] rms_att_weightLayered, - HalfFloatArray[] wqLayered, - HalfFloatArray[] wkLayered, - HalfFloatArray[] wvLayered, - HalfFloatArray[] woLayered, - FloatArray[] rms_ffn_weightLayered, - HalfFloatArray[] w1Layered, - HalfFloatArray[] w2Layered, - HalfFloatArray[] w3Layered, - FloatArray rms_final_weight_as_floatArray, - FloatArray freq_cis_realFlat, - FloatArray freq_cis_imagFlat, - HalfFloatArray wclsByteArray, + TornadoTensor tokenEmbeddingTable, + TornadoTensor[] rms_att_weightLayered, + TornadoTensor[] wqLayered, + TornadoTensor[] wkLayered, + TornadoTensor[] wvLayered, + TornadoTensor[] woLayered, + TornadoTensor[] rms_ffn_weightLayered, + TornadoTensor[] w1Layered, + TornadoTensor[] w2Layered, + TornadoTensor[] w3Layered, + TornadoTensor rms_final_weight_as_floatArray, + TornadoTensor freq_cis_realFlat, + TornadoTensor freq_cis_imagFlat, + TornadoTensor wclsByteArray, GGMLType weightType) { - // call to TornadoWeights constructor - super(tokenEmbeddingTable, - rms_att_weightLayered, - wqLayered, - wkLayered, - wvLayered, - woLayered, + super(tokenEmbeddingTable, rms_att_weightLayered, + wqLayered, wkLayered, wvLayered, woLayered, rms_ffn_weightLayered, - w1Layered, - w2Layered, - w3Layered, + w1Layered, w2Layered, w3Layered, rms_final_weight_as_floatArray, - freq_cis_realFlat, - freq_cis_imagFlat, + freq_cis_realFlat, freq_cis_imagFlat, wclsByteArray, weightType); } diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Phi3TornadoWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Phi3TornadoWeights.java index fa6d9da4..cb1ab7e9 100644 --- a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Phi3TornadoWeights.java +++ b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Phi3TornadoWeights.java @@ -1,32 +1,31 @@ package org.beehive.gpullama3.inference.weights.tornado; -import org.beehive.gpullama3.core.model.GGMLType; -import uk.ac.manchester.tornado.api.types.arrays.FloatArray; -import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.tornado.TornadoTensor; public class Phi3TornadoWeights extends TornadoWeights { // Phi3-specific weight arrays - public HalfFloatArray[] wqkvLayered; // Combined QKV weights: (layer, op_size, dim) where op_size = dim + 2 * (n_kv_heads * head_dim) - public HalfFloatArray[] wDownLayered; // FFN down projection: (layer, dim, hidden_dim) - public HalfFloatArray[] wUpLayered; // FFN up projection: (layer, hidden_dim, dim) + public TornadoTensor[] wqkvLayered; // hf - Combined QKV weights: (layer, op_size, dim) where op_size = dim + 2 * (n_kv_heads * head_dim) + public TornadoTensor[] wDownLayered; // hf - FFN down projection: (layer, dim, hidden_dim) + public TornadoTensor[] wUpLayered; // hf - FFN up projection: (layer, hidden_dim, dim) // @formatter:off public Phi3TornadoWeights( - FloatArray tokenEmbeddingTable, - FloatArray[] rms_att_weightLayered, - HalfFloatArray[] wqkvLayered, // Combined QKV weights for Phi3 - HalfFloatArray[] woLayered, - FloatArray[] rms_ffn_weightLayered, - HalfFloatArray[] wDownLayered, // FFN down weights - HalfFloatArray[] wUpLayered, // FFN up weights - FloatArray rms_final_weight_as_floatArray, - FloatArray freq_cis_realFlat, - FloatArray freq_cis_imagFlat, - HalfFloatArray wclsByteArray, + TornadoTensor 
tokenEmbeddingTable, + TornadoTensor[] rms_att_weightLayered, + TornadoTensor[] wqkvLayered, // Combined QKV weights for Phi3 + TornadoTensor[] woLayered, + TornadoTensor[] rms_ffn_weightLayered, + TornadoTensor[] wDownLayered, // FFN down weights + TornadoTensor[] wUpLayered, // FFN up weights + TornadoTensor rms_final_weight_as_floatArray, + TornadoTensor freq_cis_realFlat, + TornadoTensor freq_cis_imagFlat, + TornadoTensor wclsByteArray, GGMLType weightType) { - // Call to TornadoWeights constructor with null values for unused standard weights + // Call to BaseTornadoWeights constructor with null values for unused standard weights super(tokenEmbeddingTable, rms_att_weightLayered, null, // wqLayered - not used in Phi3, using combined wqkv instead @@ -49,4 +48,4 @@ public Phi3TornadoWeights( this.wUpLayered = wUpLayered; } // @formatter:on -} \ No newline at end of file +} diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen2TornadoWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen2TornadoWeights.java index fc7db216..6e3802d3 100644 --- a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen2TornadoWeights.java +++ b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen2TornadoWeights.java @@ -1,24 +1,34 @@ package org.beehive.gpullama3.inference.weights.tornado; -import org.beehive.gpullama3.core.model.GGMLType; -import uk.ac.manchester.tornado.api.types.arrays.FloatArray; -import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.tornado.TornadoTensor; public class Qwen2TornadoWeights extends TornadoWeights { // Qwen2-specific tornado weights - public FloatArray[] q_biasLayered; - public FloatArray[] k_biasLayered; - public FloatArray[] v_biasLayered; + public TornadoTensor[] q_biasLayered; + public TornadoTensor[] k_biasLayered; + public TornadoTensor[] v_biasLayered; - public Qwen2TornadoWeights(FloatArray tokenEmbeddingTable, FloatArray[] rms_att_weightLayered, HalfFloatArray[] wqLayered, HalfFloatArray[] wkLayered, HalfFloatArray[] wvLayered, - FloatArray[] wqBiasLayered, - FloatArray[] wkBiasLayered, - FloatArray[] wvBiasLayered, - HalfFloatArray[] woLayered, FloatArray[] rms_ffn_weightLayered, HalfFloatArray[] w1Layered, - HalfFloatArray[] w2Layered, HalfFloatArray[] w3Layered, FloatArray rms_final_weight_as_floatArray, FloatArray freq_cis_realFlat, FloatArray freq_cis_imagFlat, HalfFloatArray wclsByteArray, - GGMLType weightType) { - // call to TornadoWeights constructor + // @formatter:off + public Qwen2TornadoWeights(TornadoTensor tokenEmbeddingTable, + TornadoTensor[] rms_att_weightLayered, + TornadoTensor[] wqLayered, + TornadoTensor[] wkLayered, + TornadoTensor[] wvLayered, + TornadoTensor[] q_biasLayered, + TornadoTensor[] k_biasLayered, + TornadoTensor[] v_biasLayered, + TornadoTensor[] woLayered, + TornadoTensor[] rms_ffn_weightLayered, + TornadoTensor[] w1Layered, + TornadoTensor[] w2Layered, + TornadoTensor[] w3Layered, + TornadoTensor rms_final_weight_as_floatArray, + TornadoTensor freq_cis_realFlat, + TornadoTensor freq_cis_imagFlat, + TornadoTensor wclsByteArray, + GGMLType weightType) { super(tokenEmbeddingTable, rms_att_weightLayered, wqLayered, @@ -34,9 +44,10 @@ public Qwen2TornadoWeights(FloatArray tokenEmbeddingTable, FloatArray[] rms_att_ freq_cis_imagFlat, wclsByteArray, weightType); - // init qwen2-specific fields - this.q_biasLayered = wqBiasLayered; - this.k_biasLayered = 
wkBiasLayered; - this.v_biasLayered = wvBiasLayered; + // + this.q_biasLayered = q_biasLayered; + this.k_biasLayered = k_biasLayered; + this.v_biasLayered = v_biasLayered; } + // @formatter:on } diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen3TornadoWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen3TornadoWeights.java index 6f615d16..53a8cafd 100644 --- a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen3TornadoWeights.java +++ b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen3TornadoWeights.java @@ -1,59 +1,35 @@ package org.beehive.gpullama3.inference.weights.tornado; -import org.beehive.gpullama3.core.model.GGMLType; -import uk.ac.manchester.tornado.api.types.arrays.FloatArray; -import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.tornado.TornadoTensor; -/** - * A model-specific implementation of {@link TornadoWeights} for the Qwen3 model. - * This class encapsulates the weights required for performing GPU-accelerated - * inference of the Qwen3 model using TornadoVM. - * - *
- * <p>
- * Note: This weight format can also be used with the Mistral model.
- * </p>
- */ public class Qwen3TornadoWeights extends TornadoWeights { - - //attnKNorm - public FloatArray[] rms_att_KNormLayered; - //attnQNorm - public FloatArray[] rms_att_QNormLayered; + // Qwen3-specific fields + public final TornadoTensor[] rms_att_KNormLayered; + public final TornadoTensor[] rms_att_QNormLayered; // @formatter:off public Qwen3TornadoWeights( - FloatArray tokenEmbeddingTable, - FloatArray[] rms_att_weightLayered, - HalfFloatArray[] wqLayered, - HalfFloatArray[] wkLayered, - HalfFloatArray[] wvLayered, - HalfFloatArray[] woLayered, - FloatArray[] rms_att_KNormLayered, - FloatArray[] rms_att_QNormLayered, - FloatArray[] rms_ffn_weightLayered, - HalfFloatArray[] w1Layered, - HalfFloatArray[] w2Layered, - HalfFloatArray[] w3Layered, - FloatArray rms_final_weight_as_floatArray, - FloatArray freq_cis_realFlat, - FloatArray freq_cis_imagFlat, - HalfFloatArray wclsByteArray, + TornadoTensor tokenEmbeddingTable, + TornadoTensor[] rmsAttWeight, + TornadoTensor[] wq, + TornadoTensor[] wk, + TornadoTensor[] wv, + TornadoTensor[] wo, + TornadoTensor[] rms_att_KNormLayered, + TornadoTensor[] rms_att_QNormLayered, + TornadoTensor[] rmsFFNWeight, + TornadoTensor[] w1, + TornadoTensor[] w2, + TornadoTensor[] w3, + TornadoTensor rmsFinalWeight, + TornadoTensor freqCisReal, + TornadoTensor freqCisImag, + TornadoTensor wCls, GGMLType weightType) { - // call to TornadoWeights constructor - super(tokenEmbeddingTable, - rms_att_weightLayered, - wqLayered, - wkLayered, - wvLayered, - woLayered, - rms_ffn_weightLayered, - w1Layered, - w2Layered, - w3Layered, - rms_final_weight_as_floatArray, - freq_cis_realFlat, - freq_cis_imagFlat, - wclsByteArray, - weightType); - // init qwen3-specific fields + super(tokenEmbeddingTable, rmsAttWeight, wq, wk, wv, wo, + rmsFFNWeight, w1, w2, w3, rmsFinalWeight, + freqCisReal, freqCisImag, wCls, weightType); this.rms_att_KNormLayered = rms_att_KNormLayered; this.rms_att_QNormLayered = rms_att_QNormLayered; } diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/TornadoWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/TornadoWeights.java index 8d6b7fbc..6591dc7c 100644 --- a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/TornadoWeights.java +++ b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/TornadoWeights.java @@ -1,52 +1,67 @@ package org.beehive.gpullama3.inference.weights.tornado; -import org.beehive.gpullama3.core.model.GGMLType; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.tornado.TornadoTensor; import org.beehive.gpullama3.inference.weights.Weights; -import uk.ac.manchester.tornado.api.types.arrays.FloatArray; -import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import org.beehive.gpullama3.model.loader.ModelLoader; -//@formatter:off /** - * Base class that represents the Tornado weight format used for Java-based GPU acceleration. - * This abstract class provides the foundation for defining model-specific weights in the TornadoVM. + * Base class for TornadoVM-optimized weights. + * All weight fields are TornadoTensor types (parallel to StandardWeights using FloatTensor). + *
<p>
+ * Notes:
+ * <ul>
+ *     <li>{@link TornadoWeights#tokenEmbeddingTable} should always be loaded as F32; see {@link ModelLoader#loadTornadoTensorAsFP32}.</li>
+ *     <li>{@link TornadoWeights#rms_ffn_weightLayered} should always be loaded as F32; see {@link ModelLoader#loadTornadoTensorAsFP32}.</li>
+ *     <li>{@link TornadoWeights#rms_final_weight_as_floatArray} should always be loaded as F32; see {@link ModelLoader#loadTornadoTensorAsFP32}.</li>
+ * </ul>
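+ * <p>
+ * Illustrative loading sketch (it mirrors the GPU path in the model loaders below; variable names are placeholders, not a new API):
+ * <pre>{@code
+ * TornadoTensor embeddings = ModelLoader.loadTornadoTensorAsFP32(tokenEmbeddings);   // forced to F32
+ * TornadoTensor[] wq = ModelLoader.loadArrayOfTornadoTensors(numLayers,
+ *         i -> tensorEntries.get("blk." + i + ".attn_q.weight"));                    // kept as F16/Q8_0
+ * }</pre>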
*/ public abstract class TornadoWeights implements Weights { + // Token embedding table + public final TornadoTensor tokenEmbeddingTable; // (vocab_size, dim) - public FloatArray[] rms_att_weightLayered; // (layer, dim) rmsnorm weights - public HalfFloatArray[] wqLayered; // (layer, n_heads * head_size) - public HalfFloatArray[] wkLayered; // (layer, n_kv_heads, head_size) - public HalfFloatArray[] wvLayered; // (layer, n_kv_heads * head_size) - public HalfFloatArray[] woLayered; // (layer, n_heads * head_size, dim) - public FloatArray[] rms_ffn_weightLayered; // (layer, dim) - public HalfFloatArray[] w1Layered; // (layer, hidden_dim, dim) - public HalfFloatArray[] w2Layered; // (layer, dim, hidden_dim) - public HalfFloatArray[] w3Layered; // (layer, hidden_dim, dim) - public FloatArray rms_final_weight_as_floatArray; - public FloatArray tokenEmbeddingTable; // (vocab_size, dim) - public FloatArray freq_cis_realFlat; // (seq_len, head_size/2) - public FloatArray freq_cis_imagFlat; // (seq_len, head_size/2) - public HalfFloatArray wclsHalfFloat; + // Weights for RMSNorms + public final TornadoTensor[] rms_att_weightLayered; // (layer, dim) rmsnorm weights + + // Weights for attention + public final TornadoTensor[] wqLayered; // (layer, n_heads * head_size) + public final TornadoTensor[] wkLayered; // (layer, n_kv_heads, head_size) + public final TornadoTensor[] wvLayered; // (layer, n_kv_heads * head_size) + public final TornadoTensor[] woLayered; // (layer, n_heads * head_size, dim) + public final TornadoTensor[] rms_ffn_weightLayered; // (layer, dim) + + // Weights for FFN + public final TornadoTensor[] w1Layered; // (layer, hidden_dim, dim) + public final TornadoTensor[] w2Layered; // (layer, dim, hidden_dim) + public final TornadoTensor[] w3Layered; // (layer, hidden_dim, dim) + + // Final weights + public final TornadoTensor rms_final_weight_as_floatArray; // (dim,) + public final TornadoTensor wclsByteArray; // (vocab_size, dim) + + // RoPE frequencies (always F32) + public final TornadoTensor freq_cis_realFlat; // (seq_len, head_size/2) + public final TornadoTensor freq_cis_imagFlat; // (seq_len, head_size/2) - // (optional) classifier weights for the logits, on the last layer protected final GGMLType weightType; protected TornadoWeights( - FloatArray tokenEmbeddingTable, - FloatArray[] rms_att_weightLayered, - HalfFloatArray[] wqLayered, - HalfFloatArray[] wkLayered, - HalfFloatArray[] wvLayered, - HalfFloatArray[] woLayered, - FloatArray[] rms_ffn_weightLayered, - HalfFloatArray[] w1Layered, - HalfFloatArray[] w2Layered, - HalfFloatArray[] w3Layered, - FloatArray rms_final_weight_as_floatArray, - FloatArray freq_cis_realFlat, - FloatArray freq_cis_imagFlat, - HalfFloatArray wclsByteArray, + TornadoTensor tokenEmbeddingTable, + TornadoTensor[] rms_att_weightLayered, + TornadoTensor[] wqLayered, + TornadoTensor[] wkLayered, + TornadoTensor[] wvLayered, + TornadoTensor[] woLayered, + TornadoTensor[] rms_ffn_weightLayered, + TornadoTensor[] w1Layered, + TornadoTensor[] w2Layered, + TornadoTensor[] w3Layered, + TornadoTensor rms_final_weight_as_floatArray, + TornadoTensor freq_cis_realFlat, + TornadoTensor freq_cis_imagFlat, + TornadoTensor wclsByteArray, GGMLType weightType) { - // TornadoVM format this.tokenEmbeddingTable = tokenEmbeddingTable; this.rms_att_weightLayered = rms_att_weightLayered; this.wqLayered = wqLayered; @@ -60,14 +75,16 @@ protected TornadoWeights( this.rms_final_weight_as_floatArray = rms_final_weight_as_floatArray; this.freq_cis_realFlat = freq_cis_realFlat; 
this.freq_cis_imagFlat = freq_cis_imagFlat; - this.wclsHalfFloat = wclsByteArray; + this.wclsByteArray = wclsByteArray; this.weightType = weightType; } - //@formatter:on + + public TornadoTensor getTokenEmbeddingTable() { + return tokenEmbeddingTable; + } @Override public GGMLType getWeightType() { return weightType; } - } diff --git a/src/main/java/org/beehive/gpullama3/model/AbstractModel.java b/src/main/java/org/beehive/gpullama3/model/AbstractModel.java index d67d9ae5..c5ff3c6a 100644 --- a/src/main/java/org/beehive/gpullama3/model/AbstractModel.java +++ b/src/main/java/org/beehive/gpullama3/model/AbstractModel.java @@ -2,7 +2,7 @@ import org.beehive.gpullama3.inference.weights.Weights; import org.beehive.gpullama3.model.format.ChatFormat; -import org.beehive.gpullama3.tokenizer.impl.Tokenizer; +import org.beehive.gpullama3.tokenizer.Tokenizer; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; public abstract class AbstractModel implements Model { diff --git a/src/main/java/org/beehive/gpullama3/model/Model.java b/src/main/java/org/beehive/gpullama3/model/Model.java index b198713e..6defefd0 100644 --- a/src/main/java/org/beehive/gpullama3/model/Model.java +++ b/src/main/java/org/beehive/gpullama3/model/Model.java @@ -6,7 +6,7 @@ import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.Weights; import org.beehive.gpullama3.model.format.ChatFormat; -import org.beehive.gpullama3.tokenizer.impl.Tokenizer; +import org.beehive.gpullama3.tokenizer.Tokenizer; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; import java.util.ArrayList; diff --git a/src/main/java/org/beehive/gpullama3/model/ModelType.java b/src/main/java/org/beehive/gpullama3/model/ModelType.java index e36533b3..ce88a69b 100644 --- a/src/main/java/org/beehive/gpullama3/model/ModelType.java +++ b/src/main/java/org/beehive/gpullama3/model/ModelType.java @@ -1,6 +1,6 @@ package org.beehive.gpullama3.model; -import org.beehive.gpullama3.core.model.GGUF; +import org.beehive.gpullama3.tensor.GGUF; import org.beehive.gpullama3.model.loader.LlamaModelLoader; import org.beehive.gpullama3.model.loader.MistralModelLoader; import org.beehive.gpullama3.model.loader.Phi3ModelLoader; @@ -16,7 +16,7 @@ *
<p>Usage: Use {@code ModelType} to specify or retrieve the type of
 * large language model (LLM), such as Llama or Qwen3. This ensures clean and structured handling of model behaviors and configurations by
 * dispatching calls to the appropriate model loader for each
- * model type.</p>
+ * model type.</p>
 *
 * <p>Each enum value represents a distinct model type, which might be used for
 * conditional logic, initialization, or resource allocation within GPULlama3.java.</p>
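 *
 * <p>For example (illustrative dispatch, using the {@code loadModel} signature shown below):
 * <pre>{@code
 * Model model = ModelType.LLAMA_3.loadModel(fileChannel, gguf, contextLength, useTornadovm);
 * }</pre>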
@@ -24,55 +24,55 @@ public enum ModelType { LLAMA_3 { @Override - public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) { - return new LlamaModelLoader(fileChannel, gguf, contextLength, loadWeights, useTornadovm).loadModel(); + public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) { + return new LlamaModelLoader(fileChannel, gguf, contextLength, useTornadovm).loadModel(); } }, MISTRAL { @Override - public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) { - return new MistralModelLoader(fileChannel, gguf, contextLength, loadWeights, useTornadovm).loadModel(); + public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) { + return new MistralModelLoader(fileChannel, gguf, contextLength, useTornadovm).loadModel(); } }, QWEN_2 { @Override - public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) { - return new Qwen2ModelLoader(fileChannel, gguf, contextLength, loadWeights, useTornadovm).loadModel(); + public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) { + return new Qwen2ModelLoader(fileChannel, gguf, contextLength, useTornadovm).loadModel(); } }, QWEN_3 { @Override - public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) { - return new Qwen3ModelLoader(fileChannel, gguf, contextLength, loadWeights, useTornadovm).loadModel(); + public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) { + return new Qwen3ModelLoader(fileChannel, gguf, contextLength, useTornadovm).loadModel(); } }, DEEPSEEK_R1_DISTILL_QWEN { @Override - public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) { - return new Qwen2ModelLoader(fileChannel, gguf, contextLength, loadWeights, useTornadovm).loadModel(); + public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) { + return new Qwen2ModelLoader(fileChannel, gguf, contextLength, useTornadovm).loadModel(); } }, PHI_3 { @Override - public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) { - return new Phi3ModelLoader(fileChannel, gguf, contextLength, loadWeights, useTornadovm).loadModel(); + public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) { + return new Phi3ModelLoader(fileChannel, gguf, contextLength, useTornadovm).loadModel(); } }, UNKNOWN { @Override - public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) { + public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) { throw new UnsupportedOperationException("Cannot load unknown model type"); } }; // Abstract method that each enum constant must implement - public abstract Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm); + public abstract Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm); public boolean isDeepSeekR1() { return this == DEEPSEEK_R1_DISTILL_QWEN; diff --git a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java 
b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java index 7092de92..e2a166b0 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java @@ -1,9 +1,9 @@ package org.beehive.gpullama3.model.format; -import org.beehive.gpullama3.tokenizer.impl.LlamaTokenizer; -import org.beehive.gpullama3.tokenizer.impl.MistralTokenizer; -import org.beehive.gpullama3.tokenizer.impl.Phi3Tokenizer; -import org.beehive.gpullama3.tokenizer.impl.Qwen3Tokenizer; +import org.beehive.gpullama3.tokenizer.LlamaTokenizer; +import org.beehive.gpullama3.tokenizer.MistralTokenizer; +import org.beehive.gpullama3.tokenizer.Phi3Tokenizer; +import org.beehive.gpullama3.tokenizer.Qwen3Tokenizer; import java.util.List; import java.util.Set; diff --git a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java index bee0dcf8..80987a06 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java @@ -1,6 +1,6 @@ package org.beehive.gpullama3.model.format; -import org.beehive.gpullama3.tokenizer.impl.LlamaTokenizer; +import org.beehive.gpullama3.tokenizer.LlamaTokenizer; import java.util.ArrayList; import java.util.List; diff --git a/src/main/java/org/beehive/gpullama3/model/format/MistralChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/MistralChatFormat.java index bb7b68e0..e5680d87 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/MistralChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/MistralChatFormat.java @@ -1,6 +1,6 @@ package org.beehive.gpullama3.model.format; -import org.beehive.gpullama3.tokenizer.impl.MistralTokenizer; +import org.beehive.gpullama3.tokenizer.MistralTokenizer; import java.util.ArrayList; import java.util.Collections; diff --git a/src/main/java/org/beehive/gpullama3/model/format/Phi3ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/Phi3ChatFormat.java index 116b7757..19eb6739 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/Phi3ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/Phi3ChatFormat.java @@ -1,6 +1,6 @@ package org.beehive.gpullama3.model.format; -import org.beehive.gpullama3.tokenizer.impl.Tokenizer; +import org.beehive.gpullama3.tokenizer.Tokenizer; import java.util.ArrayList; import java.util.List; diff --git a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java index f9c81a02..7e873237 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java @@ -1,6 +1,6 @@ package org.beehive.gpullama3.model.format; -import org.beehive.gpullama3.tokenizer.impl.Qwen3Tokenizer; +import org.beehive.gpullama3.tokenizer.Qwen3Tokenizer; import java.util.*; diff --git a/src/main/java/org/beehive/gpullama3/model/llama/Llama.java b/src/main/java/org/beehive/gpullama3/model/llama/Llama.java index ede3e3ea..8c69cb40 100644 --- a/src/main/java/org/beehive/gpullama3/model/llama/Llama.java +++ b/src/main/java/org/beehive/gpullama3/model/llama/Llama.java @@ -9,8 +9,8 @@ import org.beehive.gpullama3.model.AbstractModel; import org.beehive.gpullama3.model.ModelType; import org.beehive.gpullama3.model.format.ChatFormat; -import 
org.beehive.gpullama3.tokenizer.impl.LlamaTokenizer; -import org.beehive.gpullama3.tokenizer.impl.Tokenizer; +import org.beehive.gpullama3.tokenizer.LlamaTokenizer; +import org.beehive.gpullama3.tokenizer.Tokenizer; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; import java.util.List; diff --git a/src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java new file mode 100644 index 00000000..14ffc968 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java @@ -0,0 +1,161 @@ +package org.beehive.gpullama3.model.loader; + +import org.beehive.gpullama3.tensor.GGUF; +import org.beehive.gpullama3.tensor.GGMLTensorEntry; +import org.beehive.gpullama3.auxiliary.Pair; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tokenizer.Tokenizer; +import org.beehive.gpullama3.tokenizer.Vocabulary; + +import java.io.IOException; +import java.nio.channels.FileChannel; +import java.util.Map; + +/** + * Abstract base class for model loaders using Template Method pattern. Provides common loading flow with extension points for model-specific logic. + * + * @param The specific Model type to load + * @param The specific Configuration type for the model + */ +public abstract class AbstractModelLoader { + + protected final FileChannel fileChannel; + protected final GGUF gguf; + protected final int contextLength; + protected final boolean useTornadovm; + + protected Vocabulary vocabulary; + + protected AbstractModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) { + this.fileChannel = fileChannel; + this.gguf = gguf; + this.contextLength = contextLength; + this.useTornadovm = useTornadovm; + } + + /** + * Template method that defines the model loading workflow. Subclasses should not override this method. + * + * @return The loaded model instance + */ + public final M loadModel() { + try { + Map metadata = gguf.getMetadata(); + + // Step 1: Load vocabulary + this.vocabulary = loadVocabulary(metadata); + + // Step 2: Create tokenizer + Tokenizer tokenizer = createTokenizer(metadata, vocabulary); + + // Step 3: Create configuration + C config = createConfiguration(metadata); + + // Step 4: Load tensor entries + Map tensorEntries; + if (useTornadovm) { + tensorEntries = GGUF.loadTensorsTornado(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos()); + } else { + tensorEntries = GGUF.loadTensorsStandard(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos()); + } + + // Step 4: Load weights + Weights weights = loadWeights(tensorEntries, config); + + // Step 5: Create and return model instance + return createModel(config, tokenizer, weights); + + } catch (IOException e) { + throw new ModelLoadException("Failed to load model", e); + } + } + + /** + * Load the vocabulary from GGUF metadata. Model-specific implementations should override this method. + * + * @param metadata The GGUF metadata map + * @return The loaded Vocabulary + */ + protected abstract Vocabulary loadVocabulary(Map metadata); + + /** + * Create a tokenizer instance for this model. 
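+ * A typical implementation (illustrative sketch; this is what LlamaModelLoader does) simply instantiates the model-specific tokenizer:
+ * <pre>{@code
+ * protected Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
+ *     return new LlamaTokenizer(metadata, vocabulary);
+ * }
+ * }</pre>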
+ * + * @param metadata The GGUF metadata map + * @param vocabulary The loaded vocabulary + * @return The tokenizer instance + */ + protected abstract Tokenizer createTokenizer(Map metadata, Vocabulary vocabulary); + + /** + * Create a configuration instance from GGUF metadata. + * + * @param metadata The GGUF metadata map + * @return The configuration instance + */ + protected abstract C createConfiguration(Map metadata); + + /** + * Load model weights from tensor entries. Default implementation handles common weight loading logic. + * + * @param tensorEntries Map of tensor names to tensor entries + * @param config The model configuration + * @return The loaded weights + */ + public Weights loadWeights(Map tensorEntries, C config) { + // Precompute RoPE frequencies + Pair ropeFreqs = precomputeRopeFrequencies(config); + + // Get token embeddings and output weights + GGMLTensorEntry tokenEmbeddings = getTokenEmbeddings(tensorEntries); + GGMLTensorEntry outputWeight = getOutputWeight(tensorEntries, tokenEmbeddings); + + // Delegate to specific implementation + if (useTornadovm) { + return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight); + } else { + return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight); + } + } + + /** + * Create the final model instance. + * + * @param config The model configuration + * @param tokenizer The tokenizer + * @param weights The loaded weights + * @return The model instance + */ + protected abstract M createModel(C config, Tokenizer tokenizer, Weights weights); + + /** + * Precompute RoPE frequencies for this model. Default implementation can be overridden for custom RoPE configurations. + */ + protected abstract Pair precomputeRopeFrequencies(C config); + + /** + * Get token embeddings tensor entry. Default implementation can be overridden for different tensor naming. + */ + protected GGMLTensorEntry getTokenEmbeddings(Map tensorEntries) { + return tensorEntries.get("token_embd.weight"); + } + + /** + * Get output weight tensor entry. Default implementation falls back to token embeddings if output.weight not found. + */ + protected GGMLTensorEntry getOutputWeight(Map tensorEntries, GGMLTensorEntry tokenEmbeddings) { + return tensorEntries.getOrDefault("output.weight", tokenEmbeddings); + } + + /** + * Create standard (CPU) weights. + */ + protected abstract Weights createStandardWeights(Map tensorEntries, C config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight); + + /** + * Create TornadoVM (GPU) weights. 
+ */ + protected abstract Weights createTornadoVMWeights(Map tensorEntries, C config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight); +} \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java index 79f35c92..069704a7 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java @@ -1,60 +1,141 @@ package org.beehive.gpullama3.model.loader; -import org.beehive.gpullama3.auxiliary.Timer; -import org.beehive.gpullama3.core.model.GGUF; -import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.GGUF; +import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; +import org.beehive.gpullama3.tensor.tornado.FP32TornadoTensor; +import org.beehive.gpullama3.tensor.GGMLTensorEntry; +import org.beehive.gpullama3.auxiliary.Pair; +import org.beehive.gpullama3.inference.operation.RoPE; import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.inference.weights.standard.LlamaStandardWeights; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.format.ChatFormat; import org.beehive.gpullama3.model.llama.Llama; import org.beehive.gpullama3.model.llama.LlamaConfiguration; -import org.beehive.gpullama3.tokenizer.impl.LlamaTokenizer; -import org.beehive.gpullama3.tokenizer.impl.Tokenizer; -import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary; +import org.beehive.gpullama3.tokenizer.LlamaTokenizer; +import org.beehive.gpullama3.tokenizer.Tokenizer; +import org.beehive.gpullama3.tokenizer.Vocabulary; + +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; + -import java.io.IOException; import java.nio.channels.FileChannel; import java.util.Map; -public class LlamaModelLoader extends ModelLoader { +import static org.beehive.gpullama3.model.loader.ModelLoader.*; + +public class LlamaModelLoader extends AbstractModelLoader { + + public LlamaModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) { + super(fileChannel, gguf, contextLength, useTornadovm); + } - public LlamaModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadoVM) { - super(fileChannel, gguf, contextLength, loadWeights, useTornadoVM); + @Override + protected Vocabulary loadVocabulary(Map metadata) { + return Vocabulary.loadLlamaVocabulary(metadata); + } + + @Override + protected Tokenizer createTokenizer(Map metadata, Vocabulary vocabulary) { + return new LlamaTokenizer(metadata, vocabulary); } // @formatter:off @Override - public Llama loadModel() { - try { - Map metadata = gguf.getMetadata(); - - Vocabulary vocabulary = Vocabulary.loadLlamaVocabulary(metadata); - Tokenizer tokenizer = new LlamaTokenizer(metadata, vocabulary); - - LlamaConfiguration config = new LlamaConfiguration( - (int) metadata.get("llama.embedding_length"), - (int) metadata.get("llama.feed_forward_length"), - (int) metadata.get("llama.block_count"), - (int) metadata.get("llama.attention.head_count"), - - metadata.containsKey("llama.attention.head_count_kv") ? 
- (int) metadata.get("llama.attention.head_count_kv") : - (int) metadata.get("llama.attention.head_count"), - - vocabulary.size(), - (int) metadata.get("llama.context_length"), - (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f), - (float) metadata.getOrDefault("llama.rope.freq_base", 10000f) - ).withContextLength(contextLength); - - Weights weights = null; - if (loadWeights) { - Map tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos()); - weights = loadWeights(tensorEntries, config); - } - return new Llama(config, tokenizer, weights, ChatFormat.create(tokenizer, null)); - } catch (IOException e) { - throw new RuntimeException(e); + protected LlamaConfiguration createConfiguration(Map metadata) { + int vocabSize = metadata.containsKey("llama.vocab_size") ? (int) metadata.get("llama.vocab_size") : (int) metadata.get("tokenizer.ggml.tokens.length"); + + return new LlamaConfiguration( + (int) metadata.get("llama.embedding_length"), + (int) metadata.get("llama.feed_forward_length"), + (int) metadata.get("llama.block_count"), + (int) metadata.get("llama.attention.head_count"), + metadata.containsKey("llama.attention.head_count_kv") ? + (int) metadata.get("llama.attention.head_count_kv") + : (int) metadata.get("llama.attention.head_count"), + vocabSize, + (int) metadata.get("llama.context_length"), + (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f), + (float) metadata.getOrDefault("llama.rope.freq_base", 10000f)).withContextLength(contextLength); + } + // @formatter:on + + @Override + protected Pair precomputeRopeFrequencies(LlamaConfiguration config) { + return RoPE.precomputeFreqsCis(config.contextLength(), config.dim() / config.numberOfHeads(), config.ropeTheta(), false, 1.0f, 1.0f, 1.0f, config.contextLength()); + } + + @Override + protected Llama createModel(LlamaConfiguration config, Tokenizer tokenizer, Weights weights) { + return new Llama(config, tokenizer, weights, ChatFormat.create(tokenizer, null)); + } + + // @formatter:off + @Override + protected Weights createStandardWeights(Map tensorEntries, LlamaConfiguration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, + GGMLTensorEntry outputWeight) { + + final int nl = config.numberOfLayers(); + + return new LlamaStandardWeights( + loadTensor(tokenEmbeddings), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." 
+ i + ".ffn_up.weight")), + loadTensor(tensorEntries.get("output_norm.weight")), + new ArrayFloatTensor(ropeFreqs.first()), + new ArrayFloatTensor(ropeFreqs.second()), + loadTensor(outputWeight), + outputWeight.ggmlType()); + } + // @formatter:on + + // @formatter:off + @Override + protected Weights createTornadoVMWeights(Map tensorEntries, + LlamaConfiguration config, + Pair ropeFreqs, + GGMLTensorEntry tokenEmbeddings, + GGMLTensorEntry outputWeight) { + GGMLType ggmlType = outputWeight.ggmlType(); + + if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) { + System.out.println("Loading model weights in TornadoVM format (loading " + ggmlType + ")"); } + + // Validate supported types + if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) { + throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights."); + } + + final int nl = config.numberOfLayers(); + + // Load all tensors uniformly as TornadoTensor hierarchy + return new LlamaTornadoWeights( + loadTornadoTensorAsFP32(tokenEmbeddings), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // fp32 + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), // fp32 + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." 
+ i + ".ffn_up.weight")), + loadTornadoTensor(tensorEntries.get("output_norm.weight")), // fp32 + new FP32TornadoTensor(FloatArray.fromArray(ropeFreqs.first())), + new FP32TornadoTensor(FloatArray.fromArray(ropeFreqs.second())), + loadTornadoTensor(outputWeight), + ggmlType + ); } // @formatter:on } diff --git a/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java index efe64234..25c493db 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java @@ -1,66 +1,151 @@ package org.beehive.gpullama3.model.loader; -import org.beehive.gpullama3.auxiliary.Timer; -import org.beehive.gpullama3.core.model.GGUF; -import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.GGUF; +import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; +import org.beehive.gpullama3.tensor.tornado.FP32TornadoTensor; +import org.beehive.gpullama3.tensor.GGMLTensorEntry; +import org.beehive.gpullama3.auxiliary.Pair; +import org.beehive.gpullama3.inference.operation.RoPE; import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.inference.weights.standard.LlamaStandardWeights; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.format.ChatFormat; import org.beehive.gpullama3.model.mistral.Mistral; import org.beehive.gpullama3.model.mistral.MistralConfiguration; -import org.beehive.gpullama3.tokenizer.impl.MistralTokenizer; -import org.beehive.gpullama3.tokenizer.impl.Tokenizer; -import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary; +import org.beehive.gpullama3.tokenizer.MistralTokenizer; +import org.beehive.gpullama3.tokenizer.Tokenizer; +import org.beehive.gpullama3.tokenizer.Vocabulary; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; -import java.io.IOException; import java.nio.channels.FileChannel; import java.util.Map; -public class MistralModelLoader extends ModelLoader { +import static org.beehive.gpullama3.model.loader.ModelLoader.*; - public MistralModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) { - super(fileChannel, gguf, contextLength, loadWeights, useTornadovm); +public class MistralModelLoader extends AbstractModelLoader { + + public MistralModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) { + super(fileChannel, gguf, contextLength, useTornadovm); + } + + @Override + protected Vocabulary loadVocabulary(Map metadata) { + return Vocabulary.loadMistralVocabulary(metadata); + } + + @Override + protected Tokenizer createTokenizer(Map metadata, Vocabulary vocabulary) { + return new MistralTokenizer(metadata, vocabulary); + } + + // @formatter:off + @Override + protected MistralConfiguration createConfiguration(Map metadata) { + int modelContextLength = (int) metadata.get("llama.context_length"); + int finalContextLength = (contextLength < 0 || modelContextLength < contextLength) ? modelContextLength : contextLength; + + // Get vocabulary size from metadata + int vocabSize = metadata.containsKey("llama.vocab_size") ? 
(int) metadata.get("llama.vocab_size") : (int) metadata.get("tokenizer.ggml.tokens.length"); + + return new MistralConfiguration( + (int) metadata.get("llama.embedding_length"), + (int) metadata.get("llama.feed_forward_length"), + (int) metadata.get("llama.block_count"), + (int) metadata.get("llama.attention.head_count"), + metadata.containsKey("llama.attention.head_count_kv") ? + (int) metadata.get("llama.attention.head_count_kv") + : (int) metadata.get("llama.attention.head_count"), + vocabSize, + finalContextLength, + false, + (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f), + (float) metadata.getOrDefault("llama.rope.freq_base", 10000f) + ); + } + // @formatter:on + + // @formatter:off + @Override + protected Pair precomputeRopeFrequencies(MistralConfiguration config) { + return RoPE.precomputeFreqsCis( + config.contextLength(), + config.dim() / config.numberOfHeads(), + config.ropeTheta(), + false, + 1.0f, + 1.0f, + 1.0f, + config.contextLength() + ); + } + // @formatter:on + + @Override + protected Mistral createModel(MistralConfiguration config, Tokenizer tokenizer, Weights weights) { + return new Mistral(config, tokenizer, weights, ChatFormat.create(tokenizer, null)); } // @formatter:off @Override - public Mistral loadModel() { - try { - Map metadata = gguf.getMetadata(); - - Vocabulary vocabulary = Vocabulary.loadMistralVocabulary(metadata); - Tokenizer tokenizer = new MistralTokenizer(metadata, vocabulary); - - int modelContextLength = (int) metadata.get("llama.context_length"); - if (contextLength < 0 || modelContextLength < contextLength) { - contextLength = modelContextLength; - } - - MistralConfiguration config = new MistralConfiguration( - (int) metadata.get("llama.embedding_length"), - (int) metadata.get("llama.feed_forward_length"), - (int) metadata.get("llama.block_count"), - (int) metadata.get("llama.attention.head_count"), - - metadata.containsKey("llama.attention.head_count_kv") - ? (int) metadata.get("llama.attention.head_count_kv") - : (int) metadata.get("llama.attention.head_count"), - - vocabulary.size(), - contextLength, - false, - (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f), - (float) metadata.getOrDefault("llama.rope.freq_base", 10000f) - ); - - Weights weights = null; - if (loadWeights) { - Map tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos()); - weights = loadWeights(tensorEntries, config); - } - return new Mistral(config, tokenizer, weights, ChatFormat.create(tokenizer, null)); - } catch (IOException e) { - throw new RuntimeException(e); + protected Weights createStandardWeights(Map tensorEntries, MistralConfiguration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) { + + final int nl = config.numberOfLayers(); + + return new LlamaStandardWeights( + loadTensor(tokenEmbeddings), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." 
+ i + ".ffn_down.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), + loadTensor(tensorEntries.get("output_norm.weight")), + new ArrayFloatTensor(ropeFreqs.first()), + new ArrayFloatTensor(ropeFreqs.second()), + loadTensor(outputWeight), + outputWeight.ggmlType()); + } + // @formatter:off + + // @formatter:off + @Override + protected Weights createTornadoVMWeights(Map tensorEntries, MistralConfiguration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) { + GGMLType ggmlType = outputWeight.ggmlType(); + + if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) { + System.out.println("Loading model weights in TornadoVM format (loading " + ggmlType + ")"); } + + // Validate supported types + if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) { + throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights."); + } + + final int nl = config.numberOfLayers(); + + // Load all tensors uniformly as TornadoTensor hierarchy + return new LlamaTornadoWeights( + loadTornadoTensorAsFP32(tokenEmbeddings), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // fp32 + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), // fp32 + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), + loadTornadoTensor(tensorEntries.get("output_norm.weight")), // fp32 + new FP32TornadoTensor(FloatArray.fromArray(ropeFreqs.first())), + new FP32TornadoTensor(FloatArray.fromArray(ropeFreqs.second())), + loadTornadoTensor(outputWeight), + ggmlType + ); } // @formatter:on } diff --git a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoadException.java b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoadException.java new file mode 100644 index 00000000..f09ec56b --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoadException.java @@ -0,0 +1,15 @@ +package org.beehive.gpullama3.model.loader; + +/** + * Exception thrown when model loading fails. 
+ */ +public class ModelLoadException extends RuntimeException { + + public ModelLoadException(String message) { + super(message); + } + + public ModelLoadException(String message, Throwable cause) { + super(message, cause); + } +} \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java index 270195c6..b763c4b7 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java @@ -1,43 +1,33 @@ package org.beehive.gpullama3.model.loader; import org.beehive.gpullama3.Options; -import org.beehive.gpullama3.aot.AOT; -import org.beehive.gpullama3.core.model.GGMLType; -import org.beehive.gpullama3.core.model.GGUF; -import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor; -import org.beehive.gpullama3.core.model.tensor.F16FloatTensor; -import org.beehive.gpullama3.core.model.tensor.F32FloatTensor; -import org.beehive.gpullama3.core.model.tensor.FloatTensor; -import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry; -import org.beehive.gpullama3.core.model.tensor.Q4_0FloatTensor; -import org.beehive.gpullama3.core.model.tensor.Q8_0FloatTensor; -import org.beehive.gpullama3.core.types.Pair; -import org.beehive.gpullama3.inference.operation.RoPE; -import org.beehive.gpullama3.inference.weights.Weights; -import org.beehive.gpullama3.inference.weights.standard.LlamaStandardWeights; -import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; -import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.GGUF; +import org.beehive.gpullama3.tensor.*; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.ModelType; -import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; +import org.beehive.gpullama3.tensor.standard.*; +import org.beehive.gpullama3.tensor.tornado.FP16TornadoTensor; +import org.beehive.gpullama3.tensor.tornado.FP32TornadoTensor; +import org.beehive.gpullama3.tensor.tornado.Q8_0TornadoTensor; +import org.beehive.gpullama3.tensor.tornado.TornadoTensor; import uk.ac.manchester.tornado.api.types.HalfFloat; -import uk.ac.manchester.tornado.api.types.arrays.ByteArray; -import uk.ac.manchester.tornado.api.types.arrays.FloatArray; -import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import uk.ac.manchester.tornado.api.types.arrays.*; import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; import java.nio.ByteOrder; import java.nio.FloatBuffer; import java.nio.channels.FileChannel; import java.nio.file.Path; -import java.nio.file.StandardOpenOption; import java.util.Map; +import java.util.Set; import java.util.function.IntFunction; +import java.util.stream.Collectors; public abstract class ModelLoader { - public static final boolean USE_AOT = Boolean.parseBoolean(System.getProperty("llama.AOT", "false")); // Use Ahead-of-Time compilation - protected FileChannel fileChannel; protected GGUF gguf; protected int contextLength; @@ -54,8 +44,6 @@ public ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolea private static ModelType detectModelType(Map metadata) { String name = (String) metadata.get("general.name"); - String tokenizerModel = (String) metadata.get("tokenizer.ggml.model"); - Integer vocabSize = (Integer) metadata.get("llama.vocab_size"); // Check by name first if (name != null) { 
@@ -70,10 +58,9 @@ private static ModelType detectModelType(Map metadata) { return ModelType.QWEN_3; } else if (lowerName.contains("deepseek r1 distill")) { return ModelType.DEEPSEEK_R1_DISTILL_QWEN; - } else if (lowerName.contains("phi3")) { + } else if (lowerName.contains("phi3") || lowerName.contains("phi-3")) { return ModelType.PHI_3; } - } return ModelType.UNKNOWN; @@ -81,50 +68,110 @@ private static ModelType detectModelType(Map metadata) { /** * Loads the language model based on the given options. - *
<p>
- * If Ahead-of-Time (AOT) mode is enabled, attempts to use a pre-loaded compiled model. Otherwise, loads the model from the specified path using the model loader.
- * </p>
 *
- * @param options
- *     the parsed CLI options containing model path and max token limit
+ * <p>
If Ahead-of-Time (AOT) mode is enabled, attempts to use a pre-loaded compiled model. + * Otherwise, loads the model from the specified path using the model loader. + * + * @param options the parsed CLI options containing model path and max token limit * @return the loaded {@link Model} instance - * @throws IOException - * if the model fails to load - * @throws IllegalStateException - * if AOT loading is enabled but the preloaded model is unavailable + * @throws IOException if the model fails to load + * @throws IllegalStateException if AOT loading is enabled but the preloaded model is unavailable */ public static Model loadModel(Options options) throws IOException { - if (USE_AOT) { - Model model = AOT.tryUsePreLoaded(options.modelPath(), options.maxTokens()); - if (model == null) { - throw new IllegalStateException("Failed to load precompiled AOT model."); - } - return model; - } - return ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true, options.useTornadovm()); - } + Path ggufPath = options.modelPath(); + int contextLength = options.maxTokens(); + boolean useTornadovm = options.useTornadovm(); - public static Model loadModel(Path ggufPath, int contextLength, boolean loadWeights, boolean useTornadovm) throws IOException { // initial load of metadata from gguf file - GGUF gguf = GGUF.loadModel(ggufPath); - FileChannel fileChannel = FileChannel.open(ggufPath, StandardOpenOption.READ); + GGUF gguf = GGUF.loadGGUFMetadata(ggufPath); // detect model type ModelType modelType = detectModelType(gguf.getMetadata()); // model type-specific load - return modelType.loadModel(fileChannel, gguf, contextLength, loadWeights, useTornadovm); + return modelType.loadModel(gguf.getFileChannel(), gguf, contextLength, useTornadovm); } - public static FloatTensor loadQuantized(GGMLTensorEntry entry) { + /** + * Dispatcher method for loading a standard (non-tornado) tensor based on GGML type. + * Used in CPU-path. + */ + public static FloatTensor loadTensor(GGMLTensorEntry entry) { GGMLType ggmlType = entry.ggmlType(); return switch (ggmlType) { - case F32 -> new F32FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment()); + case F32 -> new FP32FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment()); case Q8_0 -> new Q8_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment()); case Q4_0 -> new Q4_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment()); - case F16 -> new F16FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment()); + case F16 -> new FP16FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment()); default -> throw new UnsupportedOperationException("Quantization format " + ggmlType); }; } + /** + * Dispatcher method for loading a standard tensor array based on type. + * Used in CPU-path. + */ + public static FloatTensor[] loadArrayOfTensors(int size, IntFunction getTensorEntry) { + FloatTensor[] array = new FloatTensor[size]; + for (int i = 0; i < size; i++) { + array[i] = loadTensor(getTensorEntry.apply(i)); + } + return array; + } + + /** + * Dispatcher method for loading a TornadoVM-compatible tensor based on GGML type. + * Used in GPU-path. 
+ */ + public static TornadoTensor loadTornadoTensor(GGMLTensorEntry entry) { + GGMLType ggmlType = entry.ggmlType(); + return switch (ggmlType) { + case F32 -> FP32TornadoTensor.fromTornadoMemorySegment(entry.memorySegment()); + case F16 -> FP16TornadoTensor.fromTornadoMemorySegment(entry.memorySegment()); + case Q8_0 -> Q8_0TornadoTensor.createAsQ8_0(entry); + case Q4_0 -> throw new UnsupportedOperationException("Q4 format not supported yet"); + default -> throw new UnsupportedOperationException("Quantization format " + ggmlType); + }; + } + + /** + * Dispatcher method for loading an array of TornadoVM tensors based on their type. + * Used in the GPU path. + */ + public static TornadoTensor[] loadArrayOfTornadoTensors(int size, IntFunction<GGMLTensorEntry> getTensorEntry) { + TornadoTensor[] array = new TornadoTensor[size]; + for (int i = 0; i < size; i++) { + array[i] = loadTornadoTensor(getTensorEntry.apply(i)); + } + return array; + } + + /** + * Loads a tensor and manually converts it to FP32 (FloatArray). + * Used for embeddings, which are currently treated as FP32. + * TODO: this conversion is very slow and should be removed + */ + public static TornadoTensor loadTornadoTensorAsFP32(GGMLTensorEntry entry) { + TornadoTensor tensor = loadTornadoTensor(entry); + return switch (tensor.type()) { + case F32 -> tensor; + case F16 -> { + HalfFloatArray tensorHFA = tensor.asHalfFloatArray(); + int numOfElements = tensorHFA.getSize(); + FloatArray tensorFA = new FloatArray(numOfElements); + for (int i = 0; i < numOfElements; i++) { + tensorFA.set(i, tensorHFA.get(i).getFloat32()); + } + yield new FP32TornadoTensor(tensorFA); + } + case Q8_0 -> Q8_0TornadoTensor.createAsFP32(entry); + default -> { + throw new UnsupportedOperationException("Unsupported tensor type: " + tensor.type()); + } + }; + } + + // Helper methods + public static FloatArray[] loadArrayAsFloatArray(int size, IntFunction<GGMLTensorEntry> getTensorEntry) { FloatArray[] array = new FloatArray[size]; for (int i = 0; i < size; i++) { @@ -132,7 +179,6 @@ public static FloatArray[] loadArrayAsFloatArray(int size, IntFunction<GGMLTensorEntry> getTensorEntry) { HalfFloatArray[] array = new HalfFloatArray[size]; @@ -142,7 +188,13 @@ public static HalfFloatArray[] loadArrayAsHalfFloatArray(int size, IntFunction<GGMLTensorEntry> getTensorEntry) { + Q8_0TornadoTensor[] array = new Q8_0TornadoTensor[size]; + for (int i = 0; i < size; i++) { + array[i] = Q8_0TornadoTensor.createAsQ8_0(getTensorEntry.apply(i)); + } + return array; + } public static FloatArray floatBufferToFloatArray(GGMLTensorEntry tensorEntry) { if (tensorEntry.ggmlType() == GGMLType.F32) { @@ -152,7 +204,6 @@ public static FloatArray floatBufferToFloatArray(GGMLTensorEntry tensorEntry) { throw new UnsupportedOperationException("Conversion to FloatArray from " + tensorEntry.ggmlType()); } } - //@formatter:on public static FloatArray[] loadArrayAsFloatArrayFromBuffer(int size, IntFunction<GGMLTensorEntry> getTensorEntry) { FloatArray[] array = new FloatArray[size]; @@ -163,7 +214,7 @@ public static FloatArray[] loadArrayAsFloatArrayFromBuffer(int size, IntFunction } public static ByteArray createByteArrayFromTensor(GGMLTensorEntry entry) { - FloatTensor tensor = loadQuantized(entry); + FloatTensor tensor = loadTensor(entry); return ByteArray.fromSegment(tensor.asMemorySegment()); } @@ -178,7 +229,7 @@ public static FloatArray loadTensorAsFloatArray(GGMLTensorEntry entry) { return array; } else { // For quantized formats, we need to load through FloatTensor - FloatTensor tensor = loadQuantized(entry); + FloatTensor tensor = 
loadTensor(entry); FloatArray array = new FloatArray(tensor.size()); for (int i = 0; i < tensor.size(); i++) { array.set(i, tensor.getFloat(i)); @@ -193,7 +244,7 @@ public static HalfFloatArray loadTensorAsHalfFloatArray(GGMLTensorEntry entry) { return null; } else { // For quantized formats, we need to load through FloatTensor - FloatTensor tensor = loadQuantized(entry); + FloatTensor tensor = loadTensor(entry); HalfFloatArray array = new HalfFloatArray(tensor.size()); for (int i = 0; i < tensor.size(); i++) { HalfFloat x = new HalfFloat(tensor.getFloat(i)); @@ -203,14 +254,6 @@ public static HalfFloatArray loadTensorAsHalfFloatArray(GGMLTensorEntry entry) { } } - public static FloatTensor[] loadArrayOfQuantized(int size, IntFunction getTensorEntry) { - FloatTensor[] array = new FloatTensor[size]; - for (int i = 0; i < size; i++) { - array[i] = loadQuantized(getTensorEntry.apply(i)); - } - return array; - } - public static FloatBuffer[] loadArrayOfFloatBuffer(int size, IntFunction getTensorEntry) { FloatBuffer[] array = new FloatBuffer[size]; for (int i = 0; i < size; i++) { @@ -226,95 +269,4 @@ public static FloatBuffer toFloatBuffer(GGMLTensorEntry tensorEntry) { default -> throw new UnsupportedOperationException("Conversion to " + ggmlType); }; } - - public abstract Model loadModel(); - - //@formatter:off - public Weights loadWeights(Map tensorEntries, Configuration config) { - boolean ropeScaling = tensorEntries.containsKey("rope_freqs"); - RopeConfig ropeConfig = new RopeConfig(8.0f, // scaleFactor - 1.0f, // loFreqFactor - 3.0f, // hiFreqFactor - 8192 // oldContextLength - ); - - Pair ropeFreqs = RoPE.precomputeFreqsCis( - config.contextLength(), // Maximum sequence length the model can process - config.headSize(), // Dimension of each attention head - config.ropeTheta(), // Base frequency parameter (typically 10000.0) - ropeScaling, // Whether to apply frequency scaling (determined by model type) - ropeConfig.scaleFactor, // Scale factor for extending context length (NTK-aware scaling) - ropeConfig.loFreqFactor, // Low frequency scaling factor for better long-range dependencies - ropeConfig.hiFreqFactor, // High frequency scaling factor for preserving local precision - ropeConfig.oldContextLength // Original context length the model was trained with - ); - - GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight"); - GGMLTensorEntry outputWeight = tensorEntries.getOrDefault("output.weight", tokenEmbeddings); - - if (useTornadovm) { - if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) { - System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")"); - } - return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight); - } else { - return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight); - } - } - - public Weights createTornadoVMWeights(Map tensorEntries, Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, - GGMLTensorEntry outputWeight) { - return new LlamaTornadoWeights( - // Load directly to TornadoVM format - loadTensorAsFloatArray(tokenEmbeddings), loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), - loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")), - loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." 
+ i + ".attn_k.weight")), - loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")), - loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), - loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), - loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), - loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), - loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), floatBufferToFloatArray(tensorEntries.get("output_norm.weight")), - FloatArray.fromArray(ropeFreqs.first()), FloatArray.fromArray(ropeFreqs.second()), loadTensorAsHalfFloatArray(outputWeight), outputWeight.ggmlType()) { - }; - } - - /** - * Creates weights in standard format only - */ - public Weights createStandardWeights(Map tensorEntries, Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, - GGMLTensorEntry outputWeight) { - return new LlamaStandardWeights( - loadQuantized(tokenEmbeddings), - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")), - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")), - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")), - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." 
+ i + ".ffn_up.weight")), - loadQuantized(tensorEntries.get("output_norm.weight")), - new ArrayFloatTensor(ropeFreqs.first()), - new ArrayFloatTensor(ropeFreqs.second()), - loadQuantized(outputWeight), - outputWeight.ggmlType()); - } - - // Helper class to encapsulate RoPE configuration parameters - private static class RopeConfig { - final float scaleFactor; - final float loFreqFactor; - final float hiFreqFactor; - final int oldContextLength; - - RopeConfig(float scaleFactor, float loFreqFactor, float hiFreqFactor, int oldContextLength) { - this.scaleFactor = scaleFactor; - this.loFreqFactor = loFreqFactor; - this.hiFreqFactor = hiFreqFactor; - this.oldContextLength = oldContextLength; - } - } - } diff --git a/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java index d6b431c5..f32249ed 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java @@ -1,161 +1,157 @@ package org.beehive.gpullama3.model.loader; -import org.beehive.gpullama3.LlamaApp; -import org.beehive.gpullama3.Options; -import org.beehive.gpullama3.auxiliary.Timer; -import org.beehive.gpullama3.core.model.GGMLType; -import org.beehive.gpullama3.core.model.GGUF; -import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor; -import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry; -import org.beehive.gpullama3.core.types.Pair; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.GGUF; +import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; +import org.beehive.gpullama3.tensor.tornado.FP32TornadoTensor; +import org.beehive.gpullama3.tensor.GGMLTensorEntry; +import org.beehive.gpullama3.auxiliary.Pair; import org.beehive.gpullama3.inference.operation.RoPE; import org.beehive.gpullama3.inference.weights.Weights; import org.beehive.gpullama3.inference.weights.standard.Phi3StandardWeights; import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights; -import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.format.ChatFormat; import org.beehive.gpullama3.model.phi3.Phi3; import org.beehive.gpullama3.model.phi3.Phi3Configuration; -import org.beehive.gpullama3.tokenizer.impl.Phi3Tokenizer; -import org.beehive.gpullama3.tokenizer.impl.Tokenizer; -import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary; +import org.beehive.gpullama3.tokenizer.Phi3Tokenizer; +import org.beehive.gpullama3.tokenizer.Tokenizer; +import org.beehive.gpullama3.tokenizer.Vocabulary; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; -import java.io.IOException; import java.nio.channels.FileChannel; import java.util.Map; -public class Phi3ModelLoader extends ModelLoader { - public Phi3ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) { - super(fileChannel, gguf, contextLength, loadWeights, useTornadovm); +import static org.beehive.gpullama3.model.loader.ModelLoader.*; + +public class Phi3ModelLoader extends AbstractModelLoader { + private int modelContextLength; + + public Phi3ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) { + super(fileChannel, gguf, contextLength, useTornadovm); } - // @formatter:off @Override - public Phi3 loadModel() { - try { - Map metadata = gguf.getMetadata(); - final String modelPrefix 
= "phi3."; + protected Vocabulary loadVocabulary(Map metadata) { + return Vocabulary.loadPhi3Vocabulary(metadata); + } - Vocabulary vocabulary = Vocabulary.loadPhi3Vocabulary(metadata); + @Override + protected Tokenizer createTokenizer(Map metadata, Vocabulary vocabulary) { + if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) { Tokenizer tokenizer = new Phi3Tokenizer(metadata, vocabulary); - - if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) { - System.out.println("Tokenizer: " + tokenizer.getClass().getSimpleName()); - } - - int modelContextLength = (int) metadata.get(modelPrefix + "context_length"); - if (contextLength < 0 || modelContextLength < contextLength) { - contextLength = modelContextLength; - } - - Phi3Configuration config = new Phi3Configuration( - (int) metadata.get(modelPrefix + "embedding_length"), // dim - (int) metadata.get(modelPrefix + "feed_forward_length"), // hidden_dim - (int) metadata.get(modelPrefix + "block_count"), // n_layers - (int) metadata.get(modelPrefix + "attention.head_count"), // n_heads - - metadata.containsKey(modelPrefix + "attention.head_count_kv") - ? (int) metadata.get(modelPrefix + "attention.head_count_kv") - : (int) metadata.get(modelPrefix + "attention.head_count"), // n_kv_heads - - vocabulary.size(), // vocab_size - contextLength, // context_length (user-specified, not model) - (float) metadata.getOrDefault(modelPrefix + "attention.layer_norm_rms_epsilon", 1e-5f), // rms_norm_eps - (float) metadata.getOrDefault(modelPrefix + "rope.freq_base", 10000f) // rope_theta - ); - - Weights weights = null; - if (loadWeights) { - Map tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos()); - weights = loadWeights(tensorEntries, config, modelContextLength); - } - - // Phi3 chat tokens - ChatFormat.ChatTokens chatTokens = new ChatFormat.ChatTokens( - "<|system|>", "<|end|>", "<|user|>", "<|end|>", "<|assistant|>" - ); - - return new Phi3(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens)); - } catch (IOException e) { - throw new RuntimeException(e); + System.out.println("Tokenizer: " + tokenizer.getClass().getSimpleName()); + return tokenizer; } + return new Phi3Tokenizer(metadata, vocabulary); } - // @formatter:on // @formatter:off - private Weights loadWeights(Map tensorEntries, Configuration config, int modelContextLength) { + @Override + protected Phi3Configuration createConfiguration(Map metadata) { + final String modelPrefix = "phi3."; + + var config = new Phi3Configuration( + (int) metadata.get(modelPrefix + "embedding_length"), // dim + (int) metadata.get(modelPrefix + "feed_forward_length"), // hidden_dim + (int) metadata.get(modelPrefix + "block_count"), // n_layers + (int) metadata.get(modelPrefix + "attention.head_count"), // n_heads + + metadata.containsKey(modelPrefix + "attention.head_count_kv") + ? 
(int) metadata.get(modelPrefix + "attention.head_count_kv") + : (int) metadata.get(modelPrefix + "attention.head_count"), // n_kv_heads + + vocabulary.size(), // vocab_size + contextLength, // context_length (user-specified, not model) + (float) metadata.getOrDefault(modelPrefix + "attention.layer_norm_rms_epsilon", 1e-5f), // rms_norm_eps + (float) metadata.getOrDefault(modelPrefix + "rope.freq_base", 10000f) // rope_theta + ); + return config; + } + // @formatter:off + + // @formatter:off + @Override + protected Pair precomputeRopeFrequencies(Phi3Configuration config) { // Calculate head size from dim and numberOfHeads int headSize = config.dim() / config.numberOfHeads(); - Pair ropeFreqs = RoPE.precomputeFreqsCis( - modelContextLength, // Use model context length for RoPE precomputation - headSize, // Calculated head size + return RoPE.precomputeFreqsCis( + modelContextLength, // Use model context length for RoPE precomputation + headSize, // Calculated head size config.ropeTheta(), - false, // Phi3 uses standard RoPE, not neox-style based on reference - 8, 1, 3, 8192 // Additional RoPE parameters from reference + false, // Phi3 uses standard RoPE, not neox-style based on reference + 8, + 1, + 3, + 8192 // Additional RoPE parameters from reference ); + } + // @formatter:off - GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight"); - GGMLTensorEntry outputWeight = tensorEntries.get("output.weight"); // Phi3 always has separate output weight + @Override + protected Phi3 createModel(Phi3Configuration config, Tokenizer tokenizer, Weights weights) { + // Phi3 chat tokens + ChatFormat.ChatTokens chatTokens = new ChatFormat.ChatTokens("<|system|>", "<|end|>", "<|user|>", "<|end|>", "<|assistant|>"); - if (useTornadovm) { - if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) { - System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")"); - } - return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight); - } else { - return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight); - } + return new Phi3(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens)); } - // @formatter:on // @formatter:off @Override - public Weights createTornadoVMWeights(Map tensorEntries, Configuration config, - Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, - GGMLTensorEntry outputWeight) { - return new Phi3TornadoWeights( - loadTensorAsFloatArray(tokenEmbeddings), - loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), - loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")), // Combined QKV - loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo - loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), - loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // wDown - loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." 
+ i + ".ffn_up.weight")), // wUp (not combined in reference) - floatBufferToFloatArray(tensorEntries.get("output_norm.weight")), - FloatArray.fromArray(ropeFreqs.first()), - FloatArray.fromArray(ropeFreqs.second()), - loadTensorAsHalfFloatArray(outputWeight), - outputWeight.ggmlType() + protected Weights createStandardWeights(Map tensorEntries, Phi3Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) { + float[] ropeFreqsReal = ropeFreqs.first(); + float[] ropeFreqsImag = ropeFreqs.second(); + + final int nl = config.numberOfLayers(); + + return new Phi3StandardWeights( + loadTensor(tokenEmbeddings), // token_embedding_table + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // rms_att_weight (as FloatTensor[]) + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")), // wqkv (combined) + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), // rms_ffn_weight (as FloatTensor[]) + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // wDown + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // wUp (separate, not combined) + loadTensor(tensorEntries.get("output_norm.weight")), // rms_final_weight (as FloatTensor) + new ArrayFloatTensor(ropeFreqsReal), // freq_cis_real + new ArrayFloatTensor(ropeFreqsImag), // freq_cis_imag + loadTensor(outputWeight), // wcls + outputWeight.ggmlType() // weightType ); } // @formatter:on // @formatter:off @Override - public Weights createStandardWeights(Map tensorEntries, - Configuration config, - Pair ropeFreqs, - GGMLTensorEntry tokenEmbeddings, - GGMLTensorEntry outputWeight) { - float[] ropeFreqsReal = ropeFreqs.first(); - float[] ropeFreqsImag = ropeFreqs.second(); + protected Weights createTornadoVMWeights(Map tensorEntries, Phi3Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) { + GGMLType ggmlType = outputWeight.ggmlType(); - return new Phi3StandardWeights( - loadQuantized(tokenEmbeddings), // token_embedding_table - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // rms_att_weight (as FloatTensor[]) - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")), // wqkv (combined) - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), // rms_ffn_weight (as FloatTensor[]) - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // wDown - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." 
+ i + ".ffn_up.weight")), // wUp (separate, not combined) - loadQuantized(tensorEntries.get("output_norm.weight")), // rms_final_weight (as FloatTensor) - new ArrayFloatTensor(ropeFreqsReal), // freq_cis_real - new ArrayFloatTensor(ropeFreqsImag), // freq_cis_imag - loadQuantized(outputWeight), // wcls - outputWeight.ggmlType() // weightType + if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) { + System.out.println("Loading model weights in TornadoVM format (loading " + ggmlType + ")"); + } + + // Validate supported types + if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) { + throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights."); + } + + final int nl = config.numberOfLayers(); + + // Load all tensors uniformly as TornadoTensor hierarchy + return new Phi3TornadoWeights( + loadTornadoTensorAsFP32(tokenEmbeddings), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // fp32 + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), // fp32 + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), + loadTornadoTensor(tensorEntries.get("output_norm.weight")), // fp32 + new FP32TornadoTensor(FloatArray.fromArray(ropeFreqs.first())), + new FP32TornadoTensor(FloatArray.fromArray(ropeFreqs.second())), + loadTornadoTensor(outputWeight), + ggmlType ); } // @formatter:on diff --git a/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java index 0fdcce3c..c957c029 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java @@ -1,163 +1,163 @@ package org.beehive.gpullama3.model.loader; -import org.beehive.gpullama3.Options; -import org.beehive.gpullama3.core.model.GGMLType; -import org.beehive.gpullama3.core.model.GGUF; -import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor; -import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry; -import org.beehive.gpullama3.core.types.Pair; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.GGUF; +import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; +import org.beehive.gpullama3.tensor.tornado.FP32TornadoTensor; +import org.beehive.gpullama3.tensor.GGMLTensorEntry; +import org.beehive.gpullama3.auxiliary.Pair; import org.beehive.gpullama3.inference.operation.RoPE; import org.beehive.gpullama3.inference.weights.Weights; import org.beehive.gpullama3.inference.weights.standard.Qwen2StandardWeights; import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; -import org.beehive.gpullama3.model.Configuration; -import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.format.ChatFormat; import org.beehive.gpullama3.model.format.ChatFormat.ChatTokens; import org.beehive.gpullama3.model.qwen2.Qwen2; import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; -import org.beehive.gpullama3.tokenizer.impl.Qwen3Tokenizer; -import org.beehive.gpullama3.tokenizer.impl.Tokenizer; -import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary; +import 
org.beehive.gpullama3.tokenizer.Qwen3Tokenizer; +import org.beehive.gpullama3.tokenizer.Tokenizer; +import org.beehive.gpullama3.tokenizer.Vocabulary; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; -import java.io.IOException; import java.nio.channels.FileChannel; import java.util.Map; -import static org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary.loadQwen3Vocabulary; +import static org.beehive.gpullama3.model.loader.ModelLoader.*; -public class Qwen2ModelLoader extends ModelLoader { +public class Qwen2ModelLoader extends AbstractModelLoader { - public Qwen2ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) { - super(fileChannel, gguf, contextLength, loadWeights, useTornadovm); + public Qwen2ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) { + super(fileChannel, gguf, contextLength, useTornadovm); } @Override - public Model loadModel() { - Map metadata = gguf.getMetadata(); - String basename = (String) metadata.get("general.basename"); - - String modelName = "DeepSeek-R1-Distill-Qwen".equals(basename) ? "DeepSeek-R1-Distill-Qwen" : "Qwen2.5"; - - try { - // reuse method of Qwen3 - Vocabulary vocabulary = loadQwen3Vocabulary(metadata); - boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen".equals(metadata.get("general.basename")); - Tokenizer tokenizer = new Qwen3Tokenizer(metadata, vocabulary, isDeepSeekR1DistillQwen); - - int modelContextLength = (int) metadata.get("qwen2.context_length"); - if (contextLength < 0 || modelContextLength < contextLength) { - contextLength = modelContextLength; - } - - int numberOfKeyValueHeads = metadata.containsKey("qwen2.attention.head_count_kv") ? (int) metadata.get("qwen2.attention.head_count_kv") : (int) metadata.get("qwen2.attention.head_count"); - Qwen2Configuration config = new Qwen2Configuration((int) metadata.get("qwen2.embedding_length"), // dim - (int) metadata.get("qwen2.feed_forward_length"), // hiddendim - (int) metadata.get("qwen2.block_count"), // numberOfLayers - (int) metadata.get("qwen2.attention.head_count"), // numberOfHeads - - numberOfKeyValueHeads, // numberOfKeyValueHeads - numberOfKeyValueHeads, // numberOfHeadsKey - numberOfKeyValueHeads, // numberOfHeadsValue - - vocabulary.size(), modelContextLength, contextLength, false, (float) metadata.get("qwen2.attention.layer_norm_rms_epsilon"), (float) metadata.get("qwen2.rope.freq_base")); - - Weights weights = null; - if (loadWeights) { - Map tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos()); - weights = loadWeights(tensorEntries, config); - } - // Qwen2.5-Coder uses <|endoftext|> as stop-token. - ChatTokens chatTokens = isDeepSeekR1DistillQwen - ? 
new ChatTokens("<|begin▁of▁sentence|>", "", "", "<|end▁of▁sentence|>", "") - : new ChatTokens("<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>"); - return new Qwen2(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens)); - } catch (IOException e) { - throw new RuntimeException(e); - } + protected Vocabulary loadVocabulary(Map metadata) { + return Vocabulary.loadQwen3Vocabulary(metadata); + } + + @Override + protected Tokenizer createTokenizer(Map metadata, Vocabulary vocabulary) { + boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen".equals(metadata.get("general.basename")); + return new Qwen3Tokenizer(metadata, vocabulary, isDeepSeekR1DistillQwen); } // @formatter:off @Override - public Weights loadWeights(Map tensorEntries, Configuration config) { - Pair ropeFreqs = RoPE.precomputeFreqsCis( - config.contextLengthModel(), - config.headSize(), - config.ropeTheta(), + protected Qwen2Configuration createConfiguration(Map metadata) { + int modelContextLength = (int) metadata.get("qwen2.context_length"); + int finalContextLength = (contextLength < 0 || modelContextLength < contextLength) ? modelContextLength : contextLength; + + int numberOfKeyValueHeads = metadata.containsKey("qwen2.attention.head_count_kv") ? (int) metadata.get("qwen2.attention.head_count_kv") : (int) metadata.get("qwen2.attention.head_count"); + int vocabSize = vocabulary.size(); + + return new Qwen2Configuration( + (int) metadata.get("qwen2.embedding_length"), // dim + (int) metadata.get("qwen2.feed_forward_length"), // hiddendim + (int) metadata.get("qwen2.block_count"), // numberOfLayers + (int) metadata.get("qwen2.attention.head_count"), // numberOfHeads + + numberOfKeyValueHeads, // numberOfKeyValueHeads + numberOfKeyValueHeads, // numberOfHeadsKey + numberOfKeyValueHeads, // numberOfHeadsValue + + vocabSize, + modelContextLength, + finalContextLength, false, - 8, - 1, - 3, - 8192 + (float) metadata.get("qwen2.attention.layer_norm_rms_epsilon"), + (float) metadata.get("qwen2.rope.freq_base") ); + } + // @formatter:on - GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight"); - GGMLTensorEntry outputWeight = tensorEntries.getOrDefault("output.weight", tokenEmbeddings); + @Override + protected Pair precomputeRopeFrequencies(Qwen2Configuration config) { + return RoPE.precomputeFreqsCis(config.contextLengthModel(), config.headSize(), config.ropeTheta(), false, 8, 1, 3, 8192); + } - if (useTornadovm) { - if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) { - System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")"); - } - return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight); - } else { - return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight); - } + // @formatter:off + @Override + protected Qwen2 createModel(Qwen2Configuration config, Tokenizer tokenizer, Weights weights) { + Map metadata = gguf.getMetadata(); + boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen".equals(metadata.get("general.basename")); + // Qwen2.5-Coder uses <|endoftext|> as stop-token. + ChatTokens chatTokens = isDeepSeekR1DistillQwen ? 
new ChatTokens("<|begin▁of▁sentence|>", "", "", "<|end▁of▁sentence|>", "") + : new ChatTokens("<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>"); + return new Qwen2(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens)); } + // @formatter:on + // @formatter:off @Override - public Weights createStandardWeights(Map tensorEntries, Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, - GGMLTensorEntry outputWeight) { + protected Weights createStandardWeights(Map tensorEntries, Qwen2Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, + GGMLTensorEntry outputWeight) { + + final int nl = config.numberOfLayers(); + return new Qwen2StandardWeights( - loadQuantized(tokenEmbeddings), - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")), - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")), - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")), - - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.bias")), - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.bias")), - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.bias")), - - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), - loadQuantized(tensorEntries.get("output_norm.weight")), + loadTensor(tokenEmbeddings), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.bias")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.bias")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.bias")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." 
+ i + ".ffn_up.weight")), + loadTensor(tensorEntries.get("output_norm.weight")), new ArrayFloatTensor(ropeFreqs.first()), new ArrayFloatTensor(ropeFreqs.second()), - loadQuantized(outputWeight), - outputWeight.ggmlType()); + loadTensor(outputWeight), + outputWeight.ggmlType() + ); } + // @formatter:on + // @formatter:off @Override - public Weights createTornadoVMWeights(Map tensorEntries, Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, - GGMLTensorEntry outputWeight) { + protected Weights createTornadoVMWeights(Map tensorEntries, Qwen2Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, + GGMLTensorEntry outputWeight) { + GGMLType ggmlType = outputWeight.ggmlType(); + + if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) { + System.out.println("Loading model weights in TornadoVM format (loading " + ggmlType + ")"); + } + + // Validate supported types + if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) { + throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights."); + } + + final int nl = config.numberOfLayers(); + + // Load all tensors uniformly as TornadoTensor hierarchy return new Qwen2TornadoWeights( - loadTensorAsFloatArray(tokenEmbeddings), - loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), - loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")), - loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")), - loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")), - // Qwen2-specific: qkv bias - loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.bias")), - loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.bias")), - loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.bias")), - - loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), - loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), - loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), // w1 - loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // w2 - loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // w3 - floatBufferToFloatArray(tensorEntries.get("output_norm.weight")), - FloatArray.fromArray(ropeFreqs.first()), - FloatArray.fromArray(ropeFreqs.second()), - loadTensorAsHalfFloatArray(outputWeight), - outputWeight.ggmlType() + loadTornadoTensorAsFP32(tokenEmbeddings), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // fp32 + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.weight")), + // Qwen2-specific: qkv bias (always F32) + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.bias")), // fp32 + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." 
+ i + ".attn_k.bias")), // fp32 + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.bias")), // fp32 + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), // fp32 + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), + loadTornadoTensor(tensorEntries.get("output_norm.weight")), // fp32 + new FP32TornadoTensor(FloatArray.fromArray(ropeFreqs.first())), + new FP32TornadoTensor(FloatArray.fromArray(ropeFreqs.second())), + loadTornadoTensor(outputWeight), + ggmlType ); - } - // @formatter:on + } + // @formatter:off } diff --git a/src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java index 8671b8ef..008af2b3 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java @@ -1,175 +1,162 @@ package org.beehive.gpullama3.model.loader; -import org.beehive.gpullama3.Options; -import org.beehive.gpullama3.auxiliary.Timer; -import org.beehive.gpullama3.core.model.GGMLType; -import org.beehive.gpullama3.core.model.GGUF; -import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor; -import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry; -import org.beehive.gpullama3.core.types.Pair; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.GGUF; +import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; +import org.beehive.gpullama3.tensor.tornado.FP32TornadoTensor; +import org.beehive.gpullama3.tensor.GGMLTensorEntry; +import org.beehive.gpullama3.auxiliary.Pair; import org.beehive.gpullama3.inference.operation.RoPE; import org.beehive.gpullama3.inference.weights.Weights; import org.beehive.gpullama3.inference.weights.standard.Qwen3StandardWeights; import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights; -import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.format.ChatFormat; import org.beehive.gpullama3.model.format.ChatFormat.ChatTokens; import org.beehive.gpullama3.model.qwen3.Qwen3; import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; -import org.beehive.gpullama3.tokenizer.impl.Qwen3Tokenizer; -import org.beehive.gpullama3.tokenizer.impl.Tokenizer; -import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary; +import org.beehive.gpullama3.tokenizer.Qwen3Tokenizer; +import org.beehive.gpullama3.tokenizer.Tokenizer; +import org.beehive.gpullama3.tokenizer.Vocabulary; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; -import java.io.IOException; import java.nio.channels.FileChannel; import java.util.Map; -import static org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary.loadQwen3Vocabulary; +import static org.beehive.gpullama3.model.loader.ModelLoader.*; +import static org.beehive.gpullama3.tokenizer.Vocabulary.loadQwen3Vocabulary; -public class Qwen3ModelLoader extends ModelLoader { +public class Qwen3ModelLoader extends AbstractModelLoader { - public Qwen3ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) { - 
super(fileChannel, gguf, contextLength, loadWeights, useTornadovm); + public Qwen3ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) { + super(fileChannel, gguf, contextLength, useTornadovm); } - // @formatter:off @Override - public Qwen3 loadModel() { - try { - Map metadata = gguf.getMetadata(); - - Vocabulary vocabulary = loadQwen3Vocabulary(metadata); - boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen".equals(metadata.get("general.basename")); - Tokenizer tokenizer = new Qwen3Tokenizer(metadata, vocabulary, isDeepSeekR1DistillQwen); - - int modelContextLength = (int) metadata.get("qwen3.context_length"); - if (contextLength < 0 || modelContextLength < contextLength) { - contextLength = modelContextLength; - } - - Qwen3Configuration config = new Qwen3Configuration( - (int) metadata.get("qwen3.embedding_length"), - (int) metadata.get("qwen3.feed_forward_length"), - (int) metadata.get("qwen3.block_count"), - (int) metadata.get("qwen3.attention.head_count"), - - metadata.containsKey("qwen3.attention.head_count_kv") - ? (int) metadata.get("qwen3.attention.head_count_kv") - : (int) metadata.get("qwen3.attention.head_count"), - (int) metadata.get("qwen3.attention.key_length"), - (int) metadata.get("qwen3.attention.value_length"), - - vocabulary.size(), - modelContextLength, contextLength, - false, - (float) metadata.get("qwen3.attention.layer_norm_rms_epsilon"), - (float) metadata.get("qwen3.rope.freq_base") - ); - - Weights weights = null; - if (loadWeights) { - Map tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos()); - weights = loadWeights(tensorEntries, config); - } - // Qwen2.5-coder uses <|endoftext|> as stop-token. - ChatTokens chatTokens = isDeepSeekR1DistillQwen ? - new ChatTokens( "<|begin▁of▁sentence|>", "", "", "<|end▁of▁sentence|>", "") : - new ChatTokens( "<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>"); - return new Qwen3(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens)); - } catch (IOException e) { - throw new RuntimeException(e); - } + protected Vocabulary loadVocabulary(Map metadata) { + return loadQwen3Vocabulary(metadata); + } + + @Override + protected Tokenizer createTokenizer(Map metadata, Vocabulary vocabulary) { + boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen".equals(metadata.get("general.basename")); + return new Qwen3Tokenizer(metadata, vocabulary, isDeepSeekR1DistillQwen); } - // @formatter:on // @formatter:off @Override - public Weights loadWeights(Map tensorEntries, Configuration config) { - Pair ropeFreqs = RoPE.precomputeFreqsCis( - config.contextLengthModel(), - config.numberOfHeadsKey(), - config.ropeTheta(), + protected Qwen3Configuration createConfiguration(Map metadata) { + int modelContextLength = (int) metadata.get("qwen3.context_length"); + int finalContextLength = (contextLength < 0 || modelContextLength < contextLength) ? modelContextLength : contextLength; + + int vocabSize = vocabulary.size(); + + return new Qwen3Configuration( + (int) metadata.get("qwen3.embedding_length"), + (int) metadata.get("qwen3.feed_forward_length"), + (int) metadata.get("qwen3.block_count"), + (int) metadata.get("qwen3.attention.head_count"), + + metadata.containsKey("qwen3.attention.head_count_kv") ? 
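+ /* use head_count_kv if present, otherwise fall back to attention.head_count */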
+ (int) metadata.get("qwen3.attention.head_count_kv") : + (int) metadata.get("qwen3.attention.head_count"), + (int) metadata.get("qwen3.attention.key_length"), + (int) metadata.get("qwen3.attention.value_length"), + + vocabSize, + modelContextLength, + finalContextLength, false, - 0, - 0, - 0, - 0 + (float) metadata.get("qwen3.attention.layer_norm_rms_epsilon"), + (float) metadata.get("qwen3.rope.freq_base") ); - - GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight"); - GGMLTensorEntry outputWeight = tensorEntries.getOrDefault("output.weight", tokenEmbeddings); - - if (useTornadovm) { - if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) { - System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")"); - } - return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight); - } else { - return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight); - } } // @formatter:on + @Override + protected Pair precomputeRopeFrequencies(Qwen3Configuration config) { + return RoPE.precomputeFreqsCis(config.contextLengthModel(), config.numberOfHeadsKey(), config.ropeTheta(), false, 0, 0, 0, 0); + } + // @formatter:off @Override - public Weights createTornadoVMWeights(Map tensorEntries, Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, - GGMLTensorEntry outputWeight) { - return new Qwen3TornadoWeights( - loadTensorAsFloatArray(tokenEmbeddings), - loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), - loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")), - loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")), - loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")), - loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), - loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k_norm.weight")), // attnKNorm - loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q_norm.weight")), // attnQNorm - loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), - loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), // w1 - loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // w2 - loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // w3 - floatBufferToFloatArray(tensorEntries.get("output_norm.weight")), - FloatArray.fromArray(ropeFreqs.first()), - FloatArray.fromArray(ropeFreqs.second()), - loadTensorAsHalfFloatArray(outputWeight), - outputWeight.ggmlType() - ); + protected Qwen3 createModel(Qwen3Configuration config, Tokenizer tokenizer, Weights weights) { + Map metadata = gguf.getMetadata(); + boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen".equals(metadata.get("general.basename")); + // Qwen2.5-coder uses <|endoftext|> as stop-token. + ChatTokens chatTokens = isDeepSeekR1DistillQwen ? 
new ChatTokens("<|begin▁of▁sentence|>", "", "", "<|end▁of▁sentence|>", "") + : new ChatTokens("<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>"); + return new Qwen3(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens)); } - // @formatter:on + // @formatter:off // @formatter:off @Override - public Weights createStandardWeights(Map tensorEntries, - Configuration config, - Pair ropeFreqs, - GGMLTensorEntry tokenEmbeddings, - GGMLTensorEntry outputWeight) { + protected Weights createStandardWeights(Map tensorEntries, Qwen3Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, + GGMLTensorEntry outputWeight) { float[] ropeFreqsReal = ropeFreqs.first(); float[] ropeFreqsImag = ropeFreqs.second(); + + final int nl = config.numberOfLayers(); + return new Qwen3StandardWeights( - loadQuantized(tokenEmbeddings), - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // rms_att_weight - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")), // wq - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")), // wk - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")), // wv - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo - - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k_norm.weight")), // attnKNorm - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q_norm.weight")), // attnQNorm - - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), //rms_ffn_weight - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), // w1 - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // w2 - loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // w3 - loadQuantized(tensorEntries.get("output_norm.weight")), // rms_final_weight + loadTensor(tokenEmbeddings), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // rms_att_weight + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), // wq + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")), // wk + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.weight")), // wv + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k_norm.weight")), // attnKNorm + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q_norm.weight")), // attnQNorm + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), //rms_ffn_weight + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), // w1 + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // w2 + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // w3 + loadTensor(tensorEntries.get("output_norm.weight")), // rms_final_weight new ArrayFloatTensor(ropeFreqsReal), new ArrayFloatTensor(ropeFreqsImag), tensorEntries.containsKey("output.weight") - ? 
ModelLoader.loadQuantized(tensorEntries.get("output.weight")) - : loadQuantized(tokenEmbeddings), // weights are shared + ? ModelLoader.loadTensor(tensorEntries.get("output.weight")) + : loadTensor(tokenEmbeddings), // weights are shared null ); } // @formatter:on + + // @formatter:off + @Override + protected Weights createTornadoVMWeights(Map tensorEntries, Qwen3Configuration config, + Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, + GGMLTensorEntry outputWeight) { + if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) { + System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")"); + } + + GGMLType ggmlType = outputWeight.ggmlType(); + + final int nl = config.numberOfLayers(); + + return new Qwen3TornadoWeights( + loadTornadoTensorAsFP32(tokenEmbeddings), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // fp32 + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")), + // Qwen3-specific: attnKNorm and attnQNorm (always F32) + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k_norm.weight")), // fp32 + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q_norm.weight")), // fp32 + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), // fp32 + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." 
+ i + ".ffn_up.weight")), + loadTornadoTensor(tensorEntries.get("output_norm.weight")), // fp32 + new FP32TornadoTensor(FloatArray.fromArray(ropeFreqs.first())), + new FP32TornadoTensor(FloatArray.fromArray(ropeFreqs.second())), + loadTornadoTensor(outputWeight), + ggmlType + ); + + } + // @formatter:on } diff --git a/src/main/java/org/beehive/gpullama3/model/mistral/Mistral.java b/src/main/java/org/beehive/gpullama3/model/mistral/Mistral.java index 8176b85b..931f4317 100644 --- a/src/main/java/org/beehive/gpullama3/model/mistral/Mistral.java +++ b/src/main/java/org/beehive/gpullama3/model/mistral/Mistral.java @@ -9,8 +9,8 @@ import org.beehive.gpullama3.model.AbstractModel; import org.beehive.gpullama3.model.ModelType; import org.beehive.gpullama3.model.format.ChatFormat; -import org.beehive.gpullama3.tokenizer.impl.MistralTokenizer; -import org.beehive.gpullama3.tokenizer.impl.Tokenizer; +import org.beehive.gpullama3.tokenizer.MistralTokenizer; +import org.beehive.gpullama3.tokenizer.Tokenizer; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; import java.util.List; diff --git a/src/main/java/org/beehive/gpullama3/model/phi3/Phi3.java b/src/main/java/org/beehive/gpullama3/model/phi3/Phi3.java index 1ee4ce46..3328a55f 100644 --- a/src/main/java/org/beehive/gpullama3/model/phi3/Phi3.java +++ b/src/main/java/org/beehive/gpullama3/model/phi3/Phi3.java @@ -9,8 +9,8 @@ import org.beehive.gpullama3.model.AbstractModel; import org.beehive.gpullama3.model.ModelType; import org.beehive.gpullama3.model.format.ChatFormat; -import org.beehive.gpullama3.tokenizer.impl.Phi3Tokenizer; -import org.beehive.gpullama3.tokenizer.impl.Tokenizer; +import org.beehive.gpullama3.tokenizer.Phi3Tokenizer; +import org.beehive.gpullama3.tokenizer.Tokenizer; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; import java.util.List; diff --git a/src/main/java/org/beehive/gpullama3/model/qwen2/Qwen2.java b/src/main/java/org/beehive/gpullama3/model/qwen2/Qwen2.java index e8fcb581..92fdf564 100644 --- a/src/main/java/org/beehive/gpullama3/model/qwen2/Qwen2.java +++ b/src/main/java/org/beehive/gpullama3/model/qwen2/Qwen2.java @@ -9,8 +9,8 @@ import org.beehive.gpullama3.model.AbstractModel; import org.beehive.gpullama3.model.ModelType; import org.beehive.gpullama3.model.format.ChatFormat; -import org.beehive.gpullama3.tokenizer.impl.Qwen3Tokenizer; -import org.beehive.gpullama3.tokenizer.impl.Tokenizer; +import org.beehive.gpullama3.tokenizer.Qwen3Tokenizer; +import org.beehive.gpullama3.tokenizer.Tokenizer; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; import java.util.List; diff --git a/src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java b/src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java index bf90c13d..cf16b3cc 100644 --- a/src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java +++ b/src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java @@ -9,8 +9,8 @@ import org.beehive.gpullama3.model.AbstractModel; import org.beehive.gpullama3.model.ModelType; import org.beehive.gpullama3.model.format.ChatFormat; -import org.beehive.gpullama3.tokenizer.impl.Qwen3Tokenizer; -import org.beehive.gpullama3.tokenizer.impl.Tokenizer; +import org.beehive.gpullama3.tokenizer.Qwen3Tokenizer; +import org.beehive.gpullama3.tokenizer.Tokenizer; import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; import java.util.List; diff --git a/src/main/java/org/beehive/gpullama3/core/types/Float16.java b/src/main/java/org/beehive/gpullama3/tensor/Float16.java similarity index 62% rename from 
src/main/java/org/beehive/gpullama3/core/types/Float16.java rename to src/main/java/org/beehive/gpullama3/tensor/Float16.java index 6639a41b..fb171317 100644 --- a/src/main/java/org/beehive/gpullama3/core/types/Float16.java +++ b/src/main/java/org/beehive/gpullama3/tensor/Float16.java @@ -1,4 +1,4 @@ -package org.beehive.gpullama3.core.types; +package org.beehive.gpullama3.tensor; public final class Float16 { public static final int BYTES = 2; diff --git a/src/main/java/org/beehive/gpullama3/core/model/tensor/GGMLTensorEntry.java b/src/main/java/org/beehive/gpullama3/tensor/GGMLTensorEntry.java similarity index 67% rename from src/main/java/org/beehive/gpullama3/core/model/tensor/GGMLTensorEntry.java rename to src/main/java/org/beehive/gpullama3/tensor/GGMLTensorEntry.java index 8098aa11..9af9b10f 100644 --- a/src/main/java/org/beehive/gpullama3/core/model/tensor/GGMLTensorEntry.java +++ b/src/main/java/org/beehive/gpullama3/tensor/GGMLTensorEntry.java @@ -1,6 +1,4 @@ -package org.beehive.gpullama3.core.model.tensor; - -import org.beehive.gpullama3.core.model.GGMLType; +package org.beehive.gpullama3.tensor; import java.lang.foreign.MemorySegment; diff --git a/src/main/java/org/beehive/gpullama3/core/model/GGMLType.java b/src/main/java/org/beehive/gpullama3/tensor/GGMLType.java similarity index 98% rename from src/main/java/org/beehive/gpullama3/core/model/GGMLType.java rename to src/main/java/org/beehive/gpullama3/tensor/GGMLType.java index 972a4f52..f1888bb2 100644 --- a/src/main/java/org/beehive/gpullama3/core/model/GGMLType.java +++ b/src/main/java/org/beehive/gpullama3/tensor/GGMLType.java @@ -1,4 +1,4 @@ -package org.beehive.gpullama3.core.model; +package org.beehive.gpullama3.tensor; public enum GGMLType { // Floating point types diff --git a/src/main/java/org/beehive/gpullama3/core/model/GGUF.java b/src/main/java/org/beehive/gpullama3/tensor/GGUF.java similarity index 64% rename from src/main/java/org/beehive/gpullama3/core/model/GGUF.java rename to src/main/java/org/beehive/gpullama3/tensor/GGUF.java index c32cdc1d..9cdc5b7d 100644 --- a/src/main/java/org/beehive/gpullama3/core/model/GGUF.java +++ b/src/main/java/org/beehive/gpullama3/tensor/GGUF.java @@ -1,15 +1,14 @@ -package org.beehive.gpullama3.core.model; +package org.beehive.gpullama3.tensor; -import org.beehive.gpullama3.auxiliary.Timer; -import org.beehive.gpullama3.core.model.tensor.FloatTensor; -import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry; -import org.beehive.gpullama3.core.types.MetadataValueType; -import org.beehive.gpullama3.core.types.Pair; +import org.beehive.gpullama3.tensor.standard.FloatTensor; +import org.beehive.gpullama3.auxiliary.Pair; +import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray; import java.io.FileNotFoundException; import java.io.IOException; import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.channels.FileChannel; @@ -20,7 +19,11 @@ import java.util.List; import java.util.Map; +import static java.nio.file.StandardOpenOption.READ; +import static java.nio.file.StandardOpenOption.WRITE; + public final class GGUF { + private static FileChannel fileChannel; private static final int GGUF_MAGIC = 0x46554747; private static final int DEFAULT_ALIGNMENT = 32; // must be a power of 2 private static final List SUPPORTED_GGUF_VERSIONS = List.of(2, 3); @@ -37,38 +40,159 @@ public final class GGUF { private Map tensorInfos; private long 
tensorDataOffset; - public static GGUF loadModel(Path modelPath) throws IOException { + public static GGUF loadGGUFMetadata(Path modelPath) throws IOException { // file existence check if (!Files.exists(modelPath)) { throw new FileNotFoundException("Model file not found: " + modelPath); } - // second check to make sure that nothing goes wrong during model loading - try (FileChannel fileChannel = FileChannel.open(modelPath); - ) { + // Open file + try { + fileChannel = FileChannel.open(modelPath, READ, WRITE); + // Ensure we start reading from the beginning of the file + fileChannel.position(0); + } catch (Exception e) { + throw new RuntimeException("Failed to open file channel for " + modelPath, e); + } + + // Read and store the gguf metadata + try { GGUF gguf = new GGUF(); - gguf.loadModelImpl(fileChannel); + // The header of the file. + gguf.readHeader(fileChannel); // gguf_header_t header; + // Tensor infos, which can be used to locate the tensor data. + // gguf_tensor_info_t tensor_infos[header.tensor_count]; + gguf.tensorInfos = HashMap.newHashMap(gguf.tensorCount); + for (int i = 0; i < gguf.tensorCount; ++i) { + GGUF.GGUFTensorInfo ti = gguf.readTensorInfo(fileChannel); + assert !gguf.tensorInfos.containsKey(ti.name); + gguf.tensorInfos.put(ti.name, ti); + } + // Padding to the nearest multiple of `ALIGNMENT`. + // uint8_t _padding[ALIGNMENT - (sizeof(header + tensor_infos) % ALIGNMENT)]; + long _padding = (gguf.getAlignment() - (fileChannel.position() % gguf.getAlignment())) % gguf.getAlignment(); + fileChannel.position(fileChannel.position() + _padding); + // Tensor data. + // + // This is arbitrary binary data corresponding to the weights of the model. This data should be close + // or identical to the data in the original model file, but may be different due to quantization or + // other optimizations for inference. Any such deviations should be recorded in the metadata or as + // part of the architecture definition. + // + // Each tensor's data must be stored within this array, and located through its `tensor_infos` entry. + // The offset of each tensor's data must be a multiple of `ALIGNMENT`, and the space between tensors + // should be padded to `ALIGNMENT` bytes. + // uint8_t tensor_data[]; + gguf.tensorDataOffset = fileChannel.position(); return gguf; } catch (Exception e) { throw new RuntimeException("Unexpected error while loading GGUF model from " + modelPath, e); } } - public static Map loadTensors(FileChannel fileChannel, long tensorDataOffset, Map tensorInfos) throws IOException { + /** + * Loads tensor data from a given file channel based on the tensor metadata information. + * The mapping is read-only and creates standard memory segments for each tensor. 
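A note on the metadata reader above: compared to the deleted loadModelImpl, the padding computation gained an outer modulo, so a file position that is already aligned now advances by zero bytes instead of skipping a full alignment block. A minimal sketch of the arithmetic, with illustrative values:

    long alignment = 32;   // DEFAULT_ALIGNMENT, a power of 2
    long position  = 96;   // already a multiple of 32
    long naive   = alignment - (position % alignment);                // 32: would skip real data
    long padded  = (alignment - (position % alignment)) % alignment;  // 0: the corrected formula
    long bitwise = -position & (alignment - 1);                       // 0: branch-free equivalent for powers of two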
+ * + * @param fileChannel the channel from which tensor storage is read + * @param tensorDataOffset the absolute byte offset of the GGUF tensor-data section + * @param tensorInfos metadata describing all GGUF tensors + * @return a map from tensor name to {@link GGMLTensorEntry} containing + * standard memory segments for each tensor + * @throws IOException if memory mapping fails or the channel cannot be read + */ + public static Map loadTensorsStandard(FileChannel fileChannel, long tensorDataOffset, Map tensorInfos) throws IOException { Arena arena = Arena.ofAuto(); - MemorySegment tensorData = fileChannel.map(FileChannel.MapMode.READ_ONLY, tensorDataOffset, fileChannel.size() - tensorDataOffset, arena); + + // absolute file offset where the tensor-data section begins + long mappingOffset = tensorDataOffset; + // size of the entire tensor-data section + long mappingSize = fileChannel.size() - tensorDataOffset; + + MemorySegment tensorData = fileChannel.map(FileChannel.MapMode.READ_ONLY, mappingOffset, mappingSize, arena); + Map tensorEntries = HashMap.newHashMap(tensorInfos.size()); + for (Map.Entry entry : tensorInfos.entrySet()) { GGUFTensorInfo ti = entry.getValue(); + + // skip rope_freqs.weight (not needed for inference) + if (ti.name().equals("rope_freqs.weight")) { + continue; + } + int numberOfElements = FloatTensor.numberOfElements(ti.dimensions()); int sizeInBytes = Math.toIntExact(ti.ggmlType().byteSizeFor(numberOfElements)); - MemorySegment memorySegment = tensorData.asSlice(ti.offset(), sizeInBytes); + + // per-tensor slice offset; ti.offset() is relative to tensor-data start + long offset = ti.offset(); + + // per-tensor slice segment + MemorySegment memorySegment = tensorData.asSlice(offset, sizeInBytes); + tensorEntries.put(ti.name(), new GGMLTensorEntry(tensorData, ti.name(), ti.ggmlType(), ti.dimensions(), memorySegment)); } return tensorEntries; } + /** + * Loads GGUF tensor data using a TornadoVM-compatible memory layout. + * + *
<p>This method parses the GGUF tensor list and memory-maps each tensor + * in {@link TornadoNativeArray} layout directly from the underlying {@link FileChannel}. + * For compatibility with {@link TornadoNativeArray} layout, an additional header is required at + * the start of each tensor region. To satisfy this requirement, each tensor + * is mapped using {@link FileChannel.MapMode#PRIVATE} starting 16 bytes + * before the actual tensor position, providing a writable header region + * without modifying the underlying GGUF file.</p>
+ * + * @param fileChannel the channel from which tensor storage is read + * @param tensorDataOffset the absolute byte offset of the GGUF tensor-data section + * @param tensorInfos metadata describing all GGUF tensors + * @return a map from tensor name to {@link GGMLTensorEntry} containing + * TornadoVM-compatible memory segments for each tensor + * @throws IOException if memory mapping fails or the channel cannot be read + */ + public static Map loadTensorsTornado(FileChannel fileChannel, long tensorDataOffset, Map tensorInfos) throws IOException { + + Arena arena = Arena.ofAuto(); + Map tensorEntries = HashMap.newHashMap(tensorInfos.size()); + + for (Map.Entry entry : tensorInfos.entrySet()) { + GGUFTensorInfo ti = entry.getValue(); + + // skip rope_freqs.weight (not required for inference) + if (ti.name().equals("rope_freqs.weight")) { + continue; + } + + int numberOfElements = FloatTensor.numberOfElements(ti.dimensions()); + int sizeInBytes = Math.toIntExact(ti.ggmlType().byteSizeFor(numberOfElements)); + + // absolute tensor offset - relative to start of the file + long mappingOffset = tensorDataOffset + ti.offset(); + + // create memory segment in TornadoVM NativeArray layout: + // TornadoNativeArray.ARRAY_HEADER (16-byte) + tensor data + long headerBytes = TornadoNativeArray.ARRAY_HEADER; + + // start 16 bytes before the tensor position to include header space + long offset = mappingOffset - headerBytes; + long size = sizeInBytes + headerBytes; + MemorySegment memorySegment = fileChannel.map(FileChannel.MapMode.PRIVATE, offset, size, arena); + + // zero out the 16-byte header + for (int i = 0; i < headerBytes; i++) { + memorySegment.set(ValueLayout.JAVA_BYTE, i, (byte) 0); + } + + // store tornado-compatible segment + tensorEntries.put(ti.name(), new GGMLTensorEntry(memorySegment, ti.name(), ti.ggmlType(), ti.dimensions(), memorySegment)); + } + return tensorEntries; + } + public Map getTensorInfos() { return tensorInfos; } @@ -81,34 +205,8 @@ public Map getMetadata() { return metadata; } - private void loadModelImpl(FileChannel fileChannel) throws IOException { - // The header of the file. - readHeader(fileChannel); // gguf_header_t header; - // Tensor infos, which can be used to locate the tensor data. - // gguf_tensor_info_t tensor_infos[header.tensor_count]; - this.tensorInfos = HashMap.newHashMap(tensorCount); - for (int i = 0; i < tensorCount; ++i) { - GGUF.GGUFTensorInfo ti = readTensorInfo(fileChannel); - assert !tensorInfos.containsKey(ti.name); - tensorInfos.put(ti.name, ti); - } - // Padding to the nearest multiple of `ALIGNMENT`. - // uint8_t _padding[ALIGNMENT - (sizeof(header + tensor_infos) % ALIGNMENT)]; - //long _padding = -fileChannel.position() & (ALIGNMENT - 1); - long _padding = getAlignment() - (fileChannel.position() % getAlignment()); - fileChannel.position(fileChannel.position() + _padding); - // Tensor data. - // - // This is arbitrary binary data corresponding to the weights of the model. This data should be close - // or identical to the data in the original model file, but may be different due to quantization or - // other optimizations for inference. Any such deviations should be recorded in the metadata or as - // part of the architecture definition. - // - // Each tensor's data must be stored within this array, and located through its `tensor_infos` entry. - // The offset of each tensor's data must be a multiple of `ALIGNMENT`, and the space between tensors - // should be padded to `ALIGNMENT` bytes. 
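The MapMode.PRIVATE choice above is what makes the header trick safe: a private mapping is copy-on-write, so zeroing the 16-byte header region is visible in memory but never reaches the GGUF file. A self-contained demonstration of that semantics (temp-file name and values are illustrative, not part of the patch):

    import java.lang.foreign.Arena;
    import java.lang.foreign.MemorySegment;
    import java.lang.foreign.ValueLayout;
    import java.nio.channels.FileChannel;
    import java.nio.file.Files;
    import java.nio.file.Path;
    import java.nio.file.StandardOpenOption;

    public class PrivateMappingDemo {
        public static void main(String[] args) throws Exception {
            Path file = Files.write(Files.createTempFile("cow", ".bin"), new byte[] { 1, 2, 3, 4 });
            try (FileChannel ch = FileChannel.open(file, StandardOpenOption.READ, StandardOpenOption.WRITE);
                    Arena arena = Arena.ofConfined()) {
                // copy-on-write view: writes land in private pages, not in the file
                MemorySegment seg = ch.map(FileChannel.MapMode.PRIVATE, 0, 4, arena);
                seg.set(ValueLayout.JAVA_BYTE, 0, (byte) 42);
                System.out.println(seg.get(ValueLayout.JAVA_BYTE, 0)); // 42 (the mapped view)
                System.out.println(Files.readAllBytes(file)[0]);       // 1  (file unchanged)
            }
        }
    }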
- // uint8_t tensor_data[]; - this.tensorDataOffset = fileChannel.position(); + public FileChannel getFileChannel() { + return fileChannel; } private GGMLType readGGMLType(FileChannel fileChannel) throws IOException { diff --git a/src/main/java/org/beehive/gpullama3/core/types/MetadataValueType.java b/src/main/java/org/beehive/gpullama3/tensor/MetadataValueType.java similarity index 97% rename from src/main/java/org/beehive/gpullama3/core/types/MetadataValueType.java rename to src/main/java/org/beehive/gpullama3/tensor/MetadataValueType.java index 911f364d..f7e08346 100644 --- a/src/main/java/org/beehive/gpullama3/core/types/MetadataValueType.java +++ b/src/main/java/org/beehive/gpullama3/tensor/MetadataValueType.java @@ -1,4 +1,4 @@ -package org.beehive.gpullama3.core.types; +package org.beehive.gpullama3.tensor; public enum MetadataValueType { // The value is a 8-bit unsigned integer. diff --git a/src/main/java/org/beehive/gpullama3/core/model/tensor/ArrayFloatTensor.java b/src/main/java/org/beehive/gpullama3/tensor/standard/ArrayFloatTensor.java similarity index 93% rename from src/main/java/org/beehive/gpullama3/core/model/tensor/ArrayFloatTensor.java rename to src/main/java/org/beehive/gpullama3/tensor/standard/ArrayFloatTensor.java index 1214967f..d25623cc 100644 --- a/src/main/java/org/beehive/gpullama3/core/model/tensor/ArrayFloatTensor.java +++ b/src/main/java/org/beehive/gpullama3/tensor/standard/ArrayFloatTensor.java @@ -1,6 +1,6 @@ -package org.beehive.gpullama3.core.model.tensor; +package org.beehive.gpullama3.tensor.standard; -import org.beehive.gpullama3.core.model.GGMLType; +import org.beehive.gpullama3.tensor.GGMLType; import jdk.incubator.vector.FloatVector; import jdk.incubator.vector.VectorSpecies; diff --git a/src/main/java/org/beehive/gpullama3/core/model/tensor/F16FloatTensor.java b/src/main/java/org/beehive/gpullama3/tensor/standard/FP16FloatTensor.java similarity index 92% rename from src/main/java/org/beehive/gpullama3/core/model/tensor/F16FloatTensor.java rename to src/main/java/org/beehive/gpullama3/tensor/standard/FP16FloatTensor.java index 9e7ec8bf..88587072 100644 --- a/src/main/java/org/beehive/gpullama3/core/model/tensor/F16FloatTensor.java +++ b/src/main/java/org/beehive/gpullama3/tensor/standard/FP16FloatTensor.java @@ -1,6 +1,6 @@ -package org.beehive.gpullama3.core.model.tensor; +package org.beehive.gpullama3.tensor.standard; -import org.beehive.gpullama3.core.model.GGMLType; +import org.beehive.gpullama3.tensor.GGMLType; import jdk.incubator.vector.FloatVector; import jdk.incubator.vector.ShortVector; import jdk.incubator.vector.VectorOperators; @@ -9,12 +9,12 @@ import java.lang.foreign.MemorySegment; import java.nio.ByteOrder; -public final class F16FloatTensor extends FloatTensor { +public final class FP16FloatTensor extends FloatTensor { final int size; final MemorySegment memorySegment; - public F16FloatTensor(int size, MemorySegment memorySegment) { + public FP16FloatTensor(int size, MemorySegment memorySegment) { this.size = size; this.memorySegment = memorySegment; } @@ -59,7 +59,7 @@ public float dot(int thisOffset, FloatTensor that, int thatOffset, int size) { } } - private static float vectorDot(F16FloatTensor thiz, int thisOffset, ArrayFloatTensor that, int thatOffset, int size) { + private static float vectorDot(FP16FloatTensor thiz, int thisOffset, ArrayFloatTensor that, int thatOffset, int size) { assert S_SPECIES_HALF.length() == F_SPECIES.length(); FloatVector val = FloatVector.zero(F_SPECIES); int upperBound = 
F_SPECIES.loopBound(size); diff --git a/src/main/java/org/beehive/gpullama3/core/model/tensor/F32FloatTensor.java b/src/main/java/org/beehive/gpullama3/tensor/standard/FP32FloatTensor.java similarity index 82% rename from src/main/java/org/beehive/gpullama3/core/model/tensor/F32FloatTensor.java rename to src/main/java/org/beehive/gpullama3/tensor/standard/FP32FloatTensor.java index f188e9f5..2deff33e 100644 --- a/src/main/java/org/beehive/gpullama3/core/model/tensor/F32FloatTensor.java +++ b/src/main/java/org/beehive/gpullama3/tensor/standard/FP32FloatTensor.java @@ -1,17 +1,17 @@ -package org.beehive.gpullama3.core.model.tensor; +package org.beehive.gpullama3.tensor.standard; -import org.beehive.gpullama3.core.model.GGMLType; +import org.beehive.gpullama3.tensor.GGMLType; import jdk.incubator.vector.FloatVector; import jdk.incubator.vector.VectorSpecies; import java.lang.foreign.MemorySegment; import java.lang.foreign.ValueLayout; -public final class F32FloatTensor extends FloatTensor { +public final class FP32FloatTensor extends FloatTensor { final int size; final MemorySegment segment; - public F32FloatTensor(int size, MemorySegment segment) { + public FP32FloatTensor(int size, MemorySegment segment) { this.size = size; this.segment = segment; } diff --git a/src/main/java/org/beehive/gpullama3/core/model/tensor/FloatTensor.java b/src/main/java/org/beehive/gpullama3/tensor/standard/FloatTensor.java similarity index 98% rename from src/main/java/org/beehive/gpullama3/core/model/tensor/FloatTensor.java rename to src/main/java/org/beehive/gpullama3/tensor/standard/FloatTensor.java index f0c7f2cf..d91ab964 100644 --- a/src/main/java/org/beehive/gpullama3/core/model/tensor/FloatTensor.java +++ b/src/main/java/org/beehive/gpullama3/tensor/standard/FloatTensor.java @@ -1,7 +1,7 @@ -package org.beehive.gpullama3.core.model.tensor; +package org.beehive.gpullama3.tensor.standard; import org.beehive.gpullama3.auxiliary.Parallel; -import org.beehive.gpullama3.core.model.GGMLType; +import org.beehive.gpullama3.tensor.GGMLType; import jdk.incubator.vector.FloatVector; import jdk.incubator.vector.VectorShape; import jdk.incubator.vector.VectorSpecies; diff --git a/src/main/java/org/beehive/gpullama3/core/model/tensor/Q4_0FloatTensor.java b/src/main/java/org/beehive/gpullama3/tensor/standard/Q4_0FloatTensor.java similarity index 97% rename from src/main/java/org/beehive/gpullama3/core/model/tensor/Q4_0FloatTensor.java rename to src/main/java/org/beehive/gpullama3/tensor/standard/Q4_0FloatTensor.java index 8396e611..eadfc1ab 100644 --- a/src/main/java/org/beehive/gpullama3/core/model/tensor/Q4_0FloatTensor.java +++ b/src/main/java/org/beehive/gpullama3/tensor/standard/Q4_0FloatTensor.java @@ -1,8 +1,8 @@ -package org.beehive.gpullama3.core.model.tensor; +package org.beehive.gpullama3.tensor.standard; import org.beehive.gpullama3.LlamaApp; -import org.beehive.gpullama3.core.model.GGMLType; -import org.beehive.gpullama3.core.types.Float16; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.Float16; import jdk.incubator.vector.ByteVector; import jdk.incubator.vector.FloatVector; import jdk.incubator.vector.VectorOperators; diff --git a/src/main/java/org/beehive/gpullama3/core/model/tensor/Q8_0FloatTensor.java b/src/main/java/org/beehive/gpullama3/tensor/standard/Q8_0FloatTensor.java similarity index 97% rename from src/main/java/org/beehive/gpullama3/core/model/tensor/Q8_0FloatTensor.java rename to src/main/java/org/beehive/gpullama3/tensor/standard/Q8_0FloatTensor.java 
index 63a214af..9067bde0 100644 --- a/src/main/java/org/beehive/gpullama3/core/model/tensor/Q8_0FloatTensor.java +++ b/src/main/java/org/beehive/gpullama3/tensor/standard/Q8_0FloatTensor.java @@ -1,8 +1,8 @@ -package org.beehive.gpullama3.core.model.tensor; +package org.beehive.gpullama3.tensor.standard; -import org.beehive.gpullama3.core.model.GGMLType; -import org.beehive.gpullama3.core.types.Float16; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.Float16; import jdk.incubator.vector.ByteVector; import jdk.incubator.vector.FloatVector; import jdk.incubator.vector.VectorOperators; diff --git a/src/main/java/org/beehive/gpullama3/tensor/tornado/FP16TornadoTensor.java b/src/main/java/org/beehive/gpullama3/tensor/tornado/FP16TornadoTensor.java new file mode 100644 index 00000000..bcf1e3df --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tensor/tornado/FP16TornadoTensor.java @@ -0,0 +1,30 @@ +package org.beehive.gpullama3.tensor.tornado; + +import org.beehive.gpullama3.tensor.GGMLType; +import uk.ac.manchester.tornado.api.types.HalfFloat; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; + +import java.lang.foreign.MemorySegment; + +public class FP16TornadoTensor extends TornadoTensor { + private final HalfFloatArray tornadoNativeArray; + + public FP16TornadoTensor(HalfFloatArray halfFloatArray) { + this.tornadoNativeArray = halfFloatArray; + } + + public static FP16TornadoTensor fromTornadoMemorySegment(MemorySegment segment) { + return new FP16TornadoTensor(HalfFloatArray.fromSegmentShallow(segment)); + } + + @Override + public HalfFloatArray asHalfFloatArray() { + return tornadoNativeArray; + } + + @Override + public GGMLType type() { + return GGMLType.F16; + } +} + diff --git a/src/main/java/org/beehive/gpullama3/tensor/tornado/FP32TornadoTensor.java b/src/main/java/org/beehive/gpullama3/tensor/tornado/FP32TornadoTensor.java new file mode 100644 index 00000000..a1520c36 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tensor/tornado/FP32TornadoTensor.java @@ -0,0 +1,29 @@ +package org.beehive.gpullama3.tensor.tornado; + +import org.beehive.gpullama3.tensor.GGMLType; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; + +import java.lang.foreign.MemorySegment; + +public class FP32TornadoTensor extends TornadoTensor { + private final FloatArray tornadoNativeArray; + + public FP32TornadoTensor(FloatArray floatArray) { + this.tornadoNativeArray = floatArray; + } + + public static FP32TornadoTensor fromTornadoMemorySegment(MemorySegment segment) { + return new FP32TornadoTensor(FloatArray.fromSegmentShallow(segment)); + } + + @Override + public FloatArray asFloatArray() { + return tornadoNativeArray; + } + + @Override + public GGMLType type() { + return GGMLType.F32; + } + +} diff --git a/src/main/java/org/beehive/gpullama3/tensor/tornado/Q8_0TornadoTensor.java b/src/main/java/org/beehive/gpullama3/tensor/tornado/Q8_0TornadoTensor.java new file mode 100644 index 00000000..296e7bfa --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tensor/tornado/Q8_0TornadoTensor.java @@ -0,0 +1,195 @@ +package org.beehive.gpullama3.tensor.tornado; + +import org.beehive.gpullama3.tensor.GGMLTensorEntry; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.standard.FloatTensor; +import uk.ac.manchester.tornado.api.types.HalfFloat; +import uk.ac.manchester.tornado.api.types.arrays.*; + +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteOrder; +import 
java.util.concurrent.*; +import java.util.stream.IntStream; + +public class Q8_0TornadoTensor extends TornadoTensor { + + private final int size; + private final HalfFloatArray scales; // One per 32-element block + private final Int8Array quants; // Quantized int8 values + private MemorySegment segment; + + public Q8_0TornadoTensor(int size, HalfFloatArray scales, Int8Array quants, MemorySegment segment) { + this.size = size; + this.scales = scales; + this.quants = quants; + this.segment = segment; + } + + public int getSize() { + return size; + } + + /** + * Returns the scale factors for GPU kernels. + * + * @return HalfFloatArray containing fp16 scale factors + */ + public HalfFloatArray getScales() { + return scales; + } + + /** + * Returns the quantized values for GPU kernels. + * + * @return Int8Array containing quantized int8 values + */ + public Int8Array getQuants() { + return quants; + } + + @Override + public GGMLType type() { + return GGMLType.Q8_0; + } + + public MemorySegment asMemorySegment() { + return segment; + } + + /** + * Dequantizes and returns a single float value. + * + * @param index Element index + * @return Dequantized float value + */ + public float getFloat(int index) { + assert 0 <= index; + int blockIdx = index / GGMLType.Q8_0.getBlockSize(); + float scale = scales.get(blockIdx).getFloat32(); + byte quant = quants.get(index); + return quant * scale; + } + + /** + * Creates a Q8_0TornadoTensor from a GGMLTensorEntry (original implementation). + */ + public static Q8_0TornadoTensor createAsQ8_0(GGMLTensorEntry entry) { + if (entry.ggmlType() != GGMLType.Q8_0) { + throw new IllegalArgumentException("Expected Q8_0 tensor, got: " + entry.ggmlType() + " for tensor: " + entry.name()); + } + + int[] shape = entry.shape(); + int size = FloatTensor.numberOfElements(shape); + int numBlocks = size / GGMLType.Q8_0.getBlockSize(); + + if (size % GGMLType.Q8_0.getBlockSize() != 0) { + throw new IllegalArgumentException("Q8_0 tensor size must be multiple of " + GGMLType.Q8_0.getBlockSize() + ", got: " + size + " for tensor: " + entry.name()); + } + + // TODO: fix Q8_0 loading in tornado layout + // currently we work around it by removing + // the tornado header from the memory segment + MemorySegment q8Segment = entry.memorySegment().asSlice(TornadoNativeArray.ARRAY_HEADER); + + // allocate the arrays for quantized data (int8) and scales (fp16) + HalfFloatArray scales = new HalfFloatArray(numBlocks); + Int8Array quants = new Int8Array(size); + + // unpack Q8_0 blocks: [2 bytes fp16 scale][32 bytes int8 quants] + ValueLayout.OfShort shortLayout = ValueLayout.JAVA_SHORT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN); + ValueLayout.OfByte byteLayout = ValueLayout.JAVA_BYTE; + + // element-wise copy and unpack from MemorySegment to HalfFloatArray scales and Int8Array quants + // use parallel streams and unroll inner loop for better performance + IntStream.range(0, numBlocks) + .parallel() + .forEach(block -> { + // TODO: use GGML type method for the 34L size + long blockOffset = block * 34L; // 34 bytes per block + + // read fp16 scale (first 2 bytes of block) + short scaleRaw = q8Segment.get(shortLayout, blockOffset); + scales.set(block, new HalfFloat(scaleRaw)); + int blockStart = block * 32; + + // read 32 int8 quantized values (remaining bytes of block) + // TODO: use GGML type method for the 32 size + for (int i = 0; i < 32; i += 4) { + // unroll inner loop for better performance + byte q0 = q8Segment.get(byteLayout, blockOffset + 2 + i); + byte q1 = q8Segment.get(byteLayout, 
blockOffset + 2 + i + 1); + byte q2 = q8Segment.get(byteLayout, blockOffset + 2 + i + 2); + byte q3 = q8Segment.get(byteLayout, blockOffset + 2 + i + 3); + + quants.set(blockStart + i, q0); + quants.set(blockStart + i + 1, q1); + quants.set(blockStart + i + 2, q2); + quants.set(blockStart + i + 3, q3); + } + }); + + return new Q8_0TornadoTensor(size, scales, quants, q8Segment); + } + + /** + * Creates an FP32TornadoTensor holding the dequantized contents of a Q8_0 GGMLTensorEntry. + * NOTE: workaround implementation to support FP32 inference. + */ + public static FP32TornadoTensor createAsFP32(GGMLTensorEntry entry) { + if (entry.ggmlType() != GGMLType.Q8_0) { + throw new IllegalArgumentException("Expected Q8_0 tensor, got: " + entry.ggmlType() + " for tensor: " + entry.name()); + } + + int[] shape = entry.shape(); + int size = FloatTensor.numberOfElements(shape); + int numBlocks = size / GGMLType.Q8_0.getBlockSize(); + + if (size % GGMLType.Q8_0.getBlockSize() != 0) { + throw new IllegalArgumentException("Q8_0 tensor size must be multiple of " + GGMLType.Q8_0.getBlockSize() + ", got: " + size + " for tensor: " + entry.name()); + } + + // TODO: fix Q8_0 loading in tornado layout + // currently we work around it by removing + // the tornado header from the memory segment + MemorySegment q8Segment = entry.memorySegment().asSlice(TornadoNativeArray.ARRAY_HEADER); + + // allocate the FloatArray to store the result + FloatArray floatArray = new FloatArray(size); + + // unpack Q8_0 blocks: [2 bytes fp16 scale][32 bytes int8 quants] + ValueLayout.OfShort shortLayout = ValueLayout.JAVA_SHORT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN); + ValueLayout.OfByte byteLayout = ValueLayout.JAVA_BYTE; + + // element-wise dequantization and copy from MemorySegment to FloatArray + // use parallel streams and unroll inner loop for better performance + IntStream.range(0, numBlocks) + .parallel() + .forEach(block -> { + // TODO: use GGML type method for the 34L size + long blockOffset = block * 34L; // 34 bytes per block + + // read fp16 scale (first 2 bytes of block) and convert to float + short scaleRaw = q8Segment.get(shortLayout, blockOffset); + float scale = Float.float16ToFloat(scaleRaw); + int blockStart = block * 32; + + // read 32 int8 quantized values (remaining bytes of block) + // TODO: use GGML type method for the 32 size + for (int i = 0; i < 32; i += 4) { + // unroll inner loop for better performance + byte q0 = q8Segment.get(byteLayout, blockOffset + 2 + i); + byte q1 = q8Segment.get(byteLayout, blockOffset + 2 + i + 1); + byte q2 = q8Segment.get(byteLayout, blockOffset + 2 + i + 2); + byte q3 = q8Segment.get(byteLayout, blockOffset + 2 + i + 3); + + floatArray.set(blockStart + i, q0 * scale); + floatArray.set(blockStart + i + 1, q1 * scale); + floatArray.set(blockStart + i + 2, q2 * scale); + floatArray.set(blockStart + i + 3, q3 * scale); + } + }); + + return new FP32TornadoTensor(floatArray); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tensor/tornado/TornadoTensor.java b/src/main/java/org/beehive/gpullama3/tensor/tornado/TornadoTensor.java new file mode 100644 index 00000000..30ae9d15 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tensor/tornado/TornadoTensor.java @@ -0,0 +1,51 @@ +package org.beehive.gpullama3.tensor.tornado; + +import org.beehive.gpullama3.tensor.GGMLType; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import uk.ac.manchester.tornado.api.types.arrays.Int8Array; + +/** + * Base 
class for TornadoVM-compatible tensor types. + * These tensors wrap TornadoVM native arrays for GPU execution. + */ +public abstract class TornadoTensor { + + public abstract GGMLType type(); + + /** + * Get as FloatArray (for F32 tensors). + * + * @throws UnsupportedOperationException if not F32 + */ + public FloatArray asFloatArray() { + throw new UnsupportedOperationException("Not a FloatArray tensor: " + this.getClass().getSimpleName()); + } + + /** + * Get as HalfFloatArray (for F16 tensors). + * + * @throws UnsupportedOperationException if not F16 + */ + public HalfFloatArray asHalfFloatArray() { + throw new UnsupportedOperationException("Not a HalfFloatArray tensor: " + this.getClass().getSimpleName()); + } + + /** + * Get quantized scales (for Q8_0 tensors). + * + * @throws UnsupportedOperationException if not quantized + */ + public HalfFloatArray getScales() { + throw new UnsupportedOperationException("Not a quantized tensor: " + this.getClass().getSimpleName()); + } + + /** + * Get quantized values (for Q8_0 tensors). + * + * @throws UnsupportedOperationException if not quantized + */ + public Int8Array getQuants() { + throw new UnsupportedOperationException("Not a quantized tensor: " + this.getClass().getSimpleName()); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tokenizer/impl/LlamaTokenizer.java b/src/main/java/org/beehive/gpullama3/tokenizer/LlamaTokenizer.java similarity index 86% rename from src/main/java/org/beehive/gpullama3/tokenizer/impl/LlamaTokenizer.java rename to src/main/java/org/beehive/gpullama3/tokenizer/LlamaTokenizer.java index 9575ff76..36a78f1e 100644 --- a/src/main/java/org/beehive/gpullama3/tokenizer/impl/LlamaTokenizer.java +++ b/src/main/java/org/beehive/gpullama3/tokenizer/LlamaTokenizer.java @@ -1,10 +1,15 @@ -package org.beehive.gpullama3.tokenizer.impl; +package org.beehive.gpullama3.tokenizer; -import org.beehive.gpullama3.core.types.Pair; -import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary; +import org.beehive.gpullama3.auxiliary.Pair; import java.nio.charset.StandardCharsets; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.stream.Collectors; @@ -13,19 +18,18 @@ /** * GPT-2-style BPE tokenizer (even though it's called "llama") with an explicit merges list. *
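Stepping back to the TornadoTensor base class just added above: the accessors deliberately throw unless the subclass matching type() overrides them, so type confusion fails fast instead of silently mis-reading GPU buffers. A hypothetical caller (not part of the patch; it borrows the GPULLlama3TypeException introduced later in this diff):

    // hypothetical fail-fast helper for call sites that require FP32 weights
    static FloatArray requireF32(TornadoTensor tensor) {
        if (tensor.type() != GGMLType.F32) {
            throw new GPULLlama3TypeException("expected F32 tensor, got " + tensor.type());
        }
        return tensor.asFloatArray(); // FP32TornadoTensor overrides this; all other subclasses throw
    }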
<p>
- * BPE (Byte Pair Encoding): - * A sub-word tokenization algorithm that iteratively merges the most frequent pairs of symbols in a corpus to build a vocabulary of common character sequences. + * BPE (Byte Pair Encoding): A sub-word tokenization algorithm that iteratively merges the most frequent pairs of symbols in a corpus to build a vocabulary of common character sequences. *
<p>
- * GPT-2-style tokenization: - * Applies BPE at the byte level, ensuring all UTF-8 inputs are representable and using tokens that preserve leading spaces (e.g., 'Ġthe'). + * GPT-2-style tokenization: Applies BPE at the byte level, ensuring all UTF-8 inputs are representable and using tokens that preserve leading spaces (e.g., 'Ġthe'). *
<p>
- * Explicit merges list: - * A fixed sequence of learned merge rules that deterministically reconstructs the tokenizer’s vocabulary during inference without retraining. + * Explicit merges list: A fixed sequence of learned merge rules that deterministically reconstructs the tokenizer’s vocabulary during inference without retraining. *
<p>
* Based on minbpe, algorithmically follows along the * GPT 2 tokenizer */ public class LlamaTokenizer implements Tokenizer { + static final Map BYTE_ENCODER = bytesToUnicode(); + static final Map BYTE_DECODER = BYTE_ENCODER.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey)); private static final String LLAMA_3_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"; // general fields private final Pattern compiledPattern; @@ -34,28 +38,6 @@ public class LlamaTokenizer implements Tokenizer { private final Map, Integer> merges; private final Map specialTokens; - public String regexPattern() { - if (compiledPattern == null) { - return null; - } - return compiledPattern.pattern(); - } - - @Override - public Map getSpecialTokens() { - return specialTokens; - } - - @Override - public boolean isSpecialToken(int tokenIndex) { - return specialTokens.containsValue(tokenIndex); - } - - @Override - public boolean shouldDisplayToken(int token) { - return !isSpecialToken(token); - } - public LlamaTokenizer(Map metadata, Vocabulary vocabulary) { // load from metadata String[] mergeLines = (String[]) metadata.get("tokenizer.ggml.merges"); @@ -83,16 +65,85 @@ public LlamaTokenizer(Map metadata, Vocabulary vocabulary) { } } + private static List findAll(Pattern pattern, String text) { + List allMatches = new ArrayList<>(); + Matcher matcher = pattern.matcher(text); + while (matcher.find()) { + allMatches.add(matcher.group()); + } + return allMatches; + } + + private static List merge(List ids, Pair pair, int idx) { + List newids = new ArrayList<>(); + int i = 0; + while (i < ids.size()) { + // if not at the very last position AND the pair matches, replace it + if (ids.get(i).equals(pair.first()) && i < ids.size() - 1 && ids.get(i + 1).equals(pair.second())) { + newids.add(idx); + i += 2; + } else { + newids.add(ids.get(i)); + i += 1; + } + } + return newids; + } + + /** + * Returns list of utf-8 byte and a corresponding list of unicode strings. The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab if + * you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. + * To avoid that, we want lookup tables between utf-8 bytes and unicode strings. And avoids mapping to whitespace/control characters the bpe code barfs on. 
+ */ + private static Map bytesToUnicode() { + List bs = new ArrayList<>(); + IntStream.rangeClosed('!', '~').forEach(bs::add); + IntStream.rangeClosed('¡', '¬').forEach(bs::add); + IntStream.rangeClosed('®', 'ÿ').forEach(bs::add); + + List cs = new ArrayList<>(bs); + int n = 0; + for (int b = 0; b < 256; ++b) { + if (!bs.contains(b)) { + bs.add(b); + cs.add(256 + n); + n += 1; + } + } + + // return dict(zip(bs, cs)) + return IntStream.range(0, bs.size()).boxed().collect(Collectors.toMap(bs::get, cs::get)); + } + + public String regexPattern() { + if (compiledPattern == null) { + return null; + } + return compiledPattern.pattern(); + } + + @Override + public Map getSpecialTokens() { + return specialTokens; + } + + @Override + public boolean isSpecialToken(int tokenIndex) { + return specialTokens.containsValue(tokenIndex); + } + + @Override + public boolean shouldDisplayToken(int token) { + return !isSpecialToken(token); + } + private int[] encodeImpl(String text) { return encode(text, Set.of()).stream().mapToInt(i -> i).toArray(); } /** - * Unlike {@link #encodeOrdinary(String)}, this function handles special tokens. - * allowed_special: can be "all"|"none"|"none_raise" or a custom set of special tokens - * if none_raise, then an error is raised if any special token is encountered in text - * this is the default tiktoken behavior right now as well - * any other behavior is either annoying, or a major footgun. + * Unlike {@link #encodeOrdinary(String)}, this function handles special tokens. allowed_special: can be "all"|"none"|"none_raise" or a custom set of special tokens if none_raise, then an error is + * raised if any special token is encountered in text this is the default tiktoken behavior right now as well any other behavior is either annoying, or a major footgun. */ public List encode(String text, Set allowedSpecial) { // decode the user desire w.r.t. handling of special tokens @@ -108,10 +159,7 @@ public List encode(String text, Set allowedSpecial) { // based on the occurrence of any exact match with any of the special tokens // we can use re.split for this. note that surrounding the pattern with () // makes it into a capturing group, so the special tokens will be included - String specialPattern = special - .stream() - .map(Pattern::quote) - .collect(Collectors.joining("|", "(", ")")); + String specialPattern = special.stream().map(Pattern::quote).collect(Collectors.joining("|", "(", ")")); String[] specialChunks = text.split(specialPattern); // now all the special characters are separated from the rest of the text @@ -129,15 +177,6 @@ public List encode(String text, Set allowedSpecial) { return ids; } - private static List findAll(Pattern pattern, String text) { - List allMatches = new ArrayList<>(); - Matcher matcher = pattern.matcher(text); - while (matcher.find()) { - allMatches.add(matcher.group()); - } - return allMatches; - } - /** * Encoding that ignores any special tokens. 
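To make the byte-to-unicode indirection concrete: every raw UTF-8 byte is remapped into a printable code point before BPE runs, which is why tokens like 'Ġthe' carry their leading space. A small sketch (helper name hypothetical; assumes the BYTE_ENCODER map built by bytesToUnicode() above and java.nio.charset.StandardCharsets):

    static String toVisible(String text) {
        StringBuilder sb = new StringBuilder();
        for (byte b : text.getBytes(StandardCharsets.UTF_8)) {
            // remap each raw byte into the printable alphabet before any merges apply
            sb.appendCodePoint(BYTE_ENCODER.get(Byte.toUnsignedInt(b)));
        }
        return sb.toString();
    }
    // toVisible(" the") -> "Ġthe": the 32 bytes below 0x20 are remapped first,
    // so the space byte 0x20 lands on code point 256 + 32 = 0x120, i.e. 'Ġ'.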
*/ @@ -189,22 +228,6 @@ private List encodeChunk(String chunk) { return ids; } - private static List merge(List ids, Pair pair, int idx) { - List newids = new ArrayList<>(); - int i = 0; - while (i < ids.size()) { - // if not at the very last position AND the pair matches, replace it - if (ids.get(i).equals(pair.first()) && i < ids.size() - 1 && ids.get(i + 1).equals(pair.second())) { - newids.add(idx); - i += 2; - } else { - newids.add(ids.get(i)); - i += 1; - } - } - return newids; - } - public String decodeImpl(List tokens) { StringBuilder sb = new StringBuilder(); for (int token : tokens) { @@ -214,38 +237,6 @@ public String decodeImpl(List tokens) { return sb.toString(); } - /** - * Returns list of utf-8 byte and a corresponding list of unicode strings. - * The reversible bpe codes work on unicode strings. - * This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. - * When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. - * This is a significant percentage of your normal, say, 32K bpe vocab. - * To avoid that, we want lookup tables between utf-8 bytes and unicode strings. - * And avoids mapping to whitespace/control characters the bpe code barfs on. - */ - private static Map bytesToUnicode() { - List bs = new ArrayList<>(); - IntStream.rangeClosed('!', '~').forEach(bs::add); - IntStream.rangeClosed('¡', '¬').forEach(bs::add); - IntStream.rangeClosed('®', 'ÿ').forEach(bs::add); - - List cs = new ArrayList<>(bs); - int n = 0; - for (int b = 0; b < 256; ++b) { - if (!bs.contains(b)) { - bs.add(b); - cs.add(256 + n); - n += 1; - } - } - - // return dict(zip(bs, cs)) - return IntStream.range(0, bs.size()).boxed().collect(Collectors.toMap(bs::get, cs::get)); - } - - static final Map BYTE_ENCODER = bytesToUnicode(); - static final Map BYTE_DECODER = BYTE_ENCODER.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey)); - public int[] encode(String text) { StringBuilder sb = new StringBuilder(); byte[] bytes = text.getBytes(StandardCharsets.UTF_8); diff --git a/src/main/java/org/beehive/gpullama3/tokenizer/impl/MistralTokenizer.java b/src/main/java/org/beehive/gpullama3/tokenizer/MistralTokenizer.java similarity index 85% rename from src/main/java/org/beehive/gpullama3/tokenizer/impl/MistralTokenizer.java rename to src/main/java/org/beehive/gpullama3/tokenizer/MistralTokenizer.java index c4264a1b..940318f9 100644 --- a/src/main/java/org/beehive/gpullama3/tokenizer/impl/MistralTokenizer.java +++ b/src/main/java/org/beehive/gpullama3/tokenizer/MistralTokenizer.java @@ -1,9 +1,12 @@ -package org.beehive.gpullama3.tokenizer.impl; - -import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary; +package org.beehive.gpullama3.tokenizer; import java.nio.charset.StandardCharsets; -import java.util.*; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; import java.util.regex.Pattern; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -11,18 +14,12 @@ /** * TikToken-style BPE tokenizer with byte fallback. *
<p>
- * TikToken-style: - * A Byte Pair Encoding (BPE) strategy that converts text to UTF-8 bytes. - * Frequent pairs of bytes (or tokens) are merged according to a learned vocabulary. - * This reduces long words into common subwords or whole-word tokens. - * If a word or character isn't found, it falls back to byte-level tokens. + * TikToken-style: A Byte Pair Encoding (BPE) strategy that converts text to UTF-8 bytes. Frequent pairs of bytes (or tokens) are merged according to a learned vocabulary. This reduces long words into + * common subwords or whole-word tokens. If a word or character isn't found, it falls back to byte-level tokens. *
<p>
- * Byte fallback: - * A fail-safe mechanism. - * It ensures every byte has a token, so any input (even unknown words, misspellings, foreign languages, emojis, or binary) can be tokenized. - * If a token is not found in the merges or vocabulary, it will fall back to the individual byte. - * Each byte is wrapped as a special token like <0xF0> — these are part of the tokenizer’s extended vocabulary. - * This guarantees reversibility: every string can be tokenized and decoded back exactly. + * Byte fallback: A fail-safe mechanism. It ensures every byte has a token, so any input (even unknown words, misspellings, foreign languages, emojis, or binary) can be tokenized. If a token is not + * found in the merges or vocabulary, it will fall back to the individual byte. Each byte is wrapped as a special token like <0xF0> — these are part of the tokenizer’s extended vocabulary. This + * guarantees reversibility: every string can be tokenized and decoded back exactly. */ public class MistralTokenizer implements Tokenizer { private static final String MISTRAL_PATTERN = "\\S+|\\s+"; @@ -34,6 +31,26 @@ public class MistralTokenizer implements Tokenizer { private final int[] tokenType; private final int byte0; + // @formatter:off + public MistralTokenizer(Map metadata, Vocabulary vocabulary) { + // load from metadata + int[] tokenTypes = (int[]) metadata.get("tokenizer.ggml.token_type"); + List specialTokensList = IntStream.range(0, vocabulary.size()).filter(t -> tokenTypes[t] != 1 && tokenTypes[t] != 6).boxed().toList(); + Map specialTokens = + IntStream.range(0, specialTokensList.size()) + .boxed() + .collect(Collectors.toMap( + t -> vocabulary.get(t), + t -> t) + ); + // init tokenizer object fields + this.vocabulary = vocabulary; + this.compiledPattern = null; + this.specialTokens = new HashMap<>(specialTokens); + this.tokenType = tokenTypes; + this.byte0 = vocabulary.getIndex("<0x00>").orElseThrow(); + } + public String regexPattern() { if (compiledPattern == null) { return null; @@ -60,26 +77,6 @@ public boolean shouldDisplayToken(int token) { public int getTokenType(int tokenIndex) { return tokenType[tokenIndex]; } - - // @formatter:off - public MistralTokenizer(Map metadata, Vocabulary vocabulary) { - // load from metadata - int[] tokenTypes = (int[]) metadata.get("tokenizer.ggml.token_type"); - List specialTokensList = IntStream.range(0, vocabulary.size()).filter(t -> tokenTypes[t] != 1 && tokenTypes[t] != 6).boxed().toList(); - Map specialTokens = - IntStream.range(0, specialTokensList.size()) - .boxed() - .collect(Collectors.toMap( - t -> vocabulary.get(t), - t -> t) - ); - // init tokenizer object fields - this.vocabulary = vocabulary; - this.compiledPattern = null; - this.specialTokens = new HashMap<>(specialTokens); - this.tokenType = tokenTypes; - this.byte0 = vocabulary.getIndex("<0x00>").orElseThrow(); - } // @formatter:on private List encodeImpl(String text) { diff --git a/src/main/java/org/beehive/gpullama3/tokenizer/impl/Phi3Tokenizer.java b/src/main/java/org/beehive/gpullama3/tokenizer/Phi3Tokenizer.java similarity index 97% rename from src/main/java/org/beehive/gpullama3/tokenizer/impl/Phi3Tokenizer.java rename to src/main/java/org/beehive/gpullama3/tokenizer/Phi3Tokenizer.java index e8e12d92..4b5167c0 100644 --- a/src/main/java/org/beehive/gpullama3/tokenizer/impl/Phi3Tokenizer.java +++ b/src/main/java/org/beehive/gpullama3/tokenizer/Phi3Tokenizer.java @@ -1,7 +1,6 @@ -package org.beehive.gpullama3.tokenizer.impl; +package org.beehive.gpullama3.tokenizer; -import 
org.beehive.gpullama3.core.types.Pair; -import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary; +import org.beehive.gpullama3.auxiliary.Pair; import java.io.ByteArrayOutputStream; import java.nio.charset.StandardCharsets; diff --git a/src/main/java/org/beehive/gpullama3/tokenizer/impl/Qwen3Tokenizer.java b/src/main/java/org/beehive/gpullama3/tokenizer/Qwen3Tokenizer.java similarity index 98% rename from src/main/java/org/beehive/gpullama3/tokenizer/impl/Qwen3Tokenizer.java rename to src/main/java/org/beehive/gpullama3/tokenizer/Qwen3Tokenizer.java index 0f8751fb..077dd536 100644 --- a/src/main/java/org/beehive/gpullama3/tokenizer/impl/Qwen3Tokenizer.java +++ b/src/main/java/org/beehive/gpullama3/tokenizer/Qwen3Tokenizer.java @@ -1,8 +1,7 @@ -package org.beehive.gpullama3.tokenizer.impl; +package org.beehive.gpullama3.tokenizer; import org.beehive.gpullama3.auxiliary.Utf8Mask; -import org.beehive.gpullama3.core.types.Pair; -import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary; +import org.beehive.gpullama3.auxiliary.Pair; import java.nio.charset.StandardCharsets; import java.util.ArrayList; @@ -18,13 +17,14 @@ import java.util.stream.IntStream; public class Qwen3Tokenizer implements Tokenizer { + static final Map BYTE_ENCODER = bytesToUnicode(); + static final Map BYTE_DECODER = BYTE_ENCODER.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey)); private final static String QWEN3_PATTERN = "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"; private final Pattern compiledPattern; private final Vocabulary vocabulary; private final Map, Integer> merges; private final Map specialTokens; private final int[] tokenTypes; - /** buffer to store incomplete UTF-8 sequence */ private final byte[] bufUtf8 = new byte[4]; /** index in UTF-8 buffer */ @@ -32,38 +32,6 @@ public class Qwen3Tokenizer implements Tokenizer { /** current UTF-8 mask */ private Utf8Mask currUtf8Mask; - @Override - public String regexPattern() { - if (compiledPattern == null) { - return null; - } - return compiledPattern.pattern(); - } - - @Override - public Map getSpecialTokens() { - return specialTokens; - } - - @Override - public boolean isSpecialToken(int tokenIndex) { - return specialTokens.containsValue(tokenIndex); - } - - @Override - public boolean shouldDisplayToken(int token) { - int tokenType = getTokenType(token); - // tokenType 4 allows the display of reasoning ( ... 
<\think> ) - return tokenType == 1 || tokenType == 4 || tokenType == 6; - } - - public int getTokenType(int tokenIndex) { - if (tokenTypes == null) { - throw new IllegalStateException("Qwen3Tokenizer hasn't been constructed using tokenTypes"); - } - return tokenTypes[tokenIndex]; - } - // @formatter:off public Qwen3Tokenizer(Map metadata, Vocabulary vocabulary, boolean isDeepSeekR1DistillQwen) { int[] tokenTypes = (int[]) metadata.get("tokenizer.ggml.token_type"); @@ -106,11 +74,6 @@ public Qwen3Tokenizer(Map metadata, Vocabulary vocabulary, boole this.merges.put(pair, mergeIndex); } } - // @formatter:on - - private int[] encodeImpl(String text) { - return encode(text, Set.of()).stream().mapToInt(i -> i).toArray(); - } static List findAll(Pattern pattern, String text) { List allMatches = new ArrayList<>(); @@ -121,6 +84,92 @@ static List findAll(Pattern pattern, String text) { return allMatches; } + static List merge(List ids, Pair pair, int idx) { + List newids = new ArrayList<>(); + int i = 0; + while (i < ids.size()) { + // if not at the very last position AND the pair matches, replace it + if (ids.get(i).equals(pair.first()) && i < ids.size() - 1 && ids.get(i + 1).equals(pair.second())) { + newids.add(idx); + i += 2; + } else { + newids.add(ids.get(i)); + i += 1; + } + } + return newids; + } + + /** + * Returns list of utf-8 byte and a corresponding list of unicode strings. + * The reversible bpe codes work on unicode strings. + * This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + * When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + * This is a significant percentage of your normal, say, 32K bpe vocab. + * To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + * And avoids mapping to whitespace/control characters the bpe code barfs on. + */ + static Map bytesToUnicode() { + List bs = new ArrayList<>(); + IntStream.rangeClosed('!', '~').forEach(bs::add); + IntStream.rangeClosed('¡', '¬').forEach(bs::add); + IntStream.rangeClosed('®', 'ÿ').forEach(bs::add); + + List cs = new ArrayList<>(bs); + int n = 0; + for (int b = 0; b < 256; ++b) { + if (!bs.contains(b)) { + bs.add(b); + cs.add(256 + n); + n += 1; + } + } + + // return dict(zip(bs, cs)) + return IntStream.range(0, bs.size()) + .boxed() + .collect(Collectors.toMap(bs::get, cs::get)); + } + // @formatter:on + + @Override + public String regexPattern() { + if (compiledPattern == null) { + return null; + } + return compiledPattern.pattern(); + } + + @Override + public Map getSpecialTokens() { + return specialTokens; + } + + @Override + public boolean isSpecialToken(int tokenIndex) { + return specialTokens.containsValue(tokenIndex); + } + + @Override + public boolean shouldDisplayToken(int token) { + int tokenType = getTokenType(token); + // tokenType 4 allows the display of reasoning ( ... <\think> ) + return tokenType == 1 || tokenType == 4 || tokenType == 6; + } + + public int getTokenType(int tokenIndex) { + if (tokenTypes == null) { + throw new IllegalStateException("Qwen3Tokenizer hasn't been constructed using tokenTypes"); + } + return tokenTypes[tokenIndex]; + } + + private int[] encodeImpl(String text) { + return encode(text, Set.of()).stream().mapToInt(i -> i).toArray(); + } + + // @formatter:off + /** * Encoding that ignores any special tokens. 
*/ @@ -135,6 +184,7 @@ public List encodeOrdinary(String text) { } return ids; } + // @formatter:on private Map, Integer> getStats(List ids) { Map, Integer> map = new HashMap<>(); @@ -172,58 +222,6 @@ private List encodeChunk(String chunk) { return ids; } - static List merge(List ids, Pair pair, int idx) { - List newids = new ArrayList<>(); - int i = 0; - while (i < ids.size()) { - // if not at the very last position AND the pair matches, replace it - if (ids.get(i).equals(pair.first()) && i < ids.size() - 1 && ids.get(i + 1).equals(pair.second())) { - newids.add(idx); - i += 2; - } else { - newids.add(ids.get(i)); - i += 1; - } - } - return newids; - } - - // @formatter:off - /** - * Returns list of utf-8 byte and a corresponding list of unicode strings. - * The reversible bpe codes work on unicode strings. - * This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. - * When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. - * This is a significant percentage of your normal, say, 32K bpe vocab. - * To avoid that, we want lookup tables between utf-8 bytes and unicode strings. - * And avoids mapping to whitespace/control characters the bpe code barfs on. - */ - static Map bytesToUnicode() { - List bs = new ArrayList<>(); - IntStream.rangeClosed('!', '~').forEach(bs::add); - IntStream.rangeClosed('¡', '¬').forEach(bs::add); - IntStream.rangeClosed('®', 'ÿ').forEach(bs::add); - - List cs = new ArrayList<>(bs); - int n = 0; - for (int b = 0; b < 256; ++b) { - if (!bs.contains(b)) { - bs.add(b); - cs.add(256 + n); - n += 1; - } - } - - // return dict(zip(bs, cs)) - return IntStream.range(0, bs.size()) - .boxed() - .collect(Collectors.toMap(bs::get, cs::get)); - } - // @formatter:on - - static final Map BYTE_ENCODER = bytesToUnicode(); - static final Map BYTE_DECODER = BYTE_ENCODER.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey)); - public int[] encode(String text) { StringBuilder sb = new StringBuilder(); byte[] bytes = text.getBytes(StandardCharsets.UTF_8); @@ -290,8 +288,6 @@ public List encodeAsList(String text) { return Arrays.stream(encode(text)).boxed().toList(); } - - public String decodeImpl(List tokens) { StringBuilder sb = new StringBuilder(); for (int token : tokens) { diff --git a/src/main/java/org/beehive/gpullama3/tokenizer/impl/Tokenizer.java b/src/main/java/org/beehive/gpullama3/tokenizer/Tokenizer.java similarity index 89% rename from src/main/java/org/beehive/gpullama3/tokenizer/impl/Tokenizer.java rename to src/main/java/org/beehive/gpullama3/tokenizer/Tokenizer.java index 8419019d..ec67c5f5 100644 --- a/src/main/java/org/beehive/gpullama3/tokenizer/impl/Tokenizer.java +++ b/src/main/java/org/beehive/gpullama3/tokenizer/Tokenizer.java @@ -1,4 +1,4 @@ -package org.beehive.gpullama3.tokenizer.impl; +package org.beehive.gpullama3.tokenizer; import java.util.HexFormat; import java.util.List; @@ -6,27 +6,6 @@ import java.util.Set; public interface Tokenizer { - String regexPattern(); - - Map getSpecialTokens(); - - boolean isSpecialToken(int tokenIndex); - - /** - * Determines if a token should be displayed during streaming output. - * This filters out special tokens, control characters, or other non-displayable content. 
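For reference, the intended effect of the replaceControlCharacters utility that stays at the top of the interface (escape format inferred from the comment and the java.util.HexFormat import, so treat it as an assumption):

    String safe = Tokenizer.replaceControlCharacters("ok\u0007\n");
    // BEL (0x07) is a control character and is rendered as the visible text \\u0007;
    // '\n' is deliberately kept, so streamed output stays line-oriented but cannot
    // smuggle terminal control sequences.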
- * - * @param token the token to check - * @return true if the token should be displayed to the user, false otherwise - */ - boolean shouldDisplayToken(int token); - - List encode(String text, Set allowedSpecial); - - List encodeAsList(String text); - - String decode(List tokens); - // Utility method for all tokenizers, implemented as static. static String replaceControlCharacters(int[] codePoints) { // we don't want to print control characters @@ -49,5 +28,26 @@ static String replaceControlCharacters(String str) { return replaceControlCharacters(str.codePoints().toArray()); } + String regexPattern(); + + Map getSpecialTokens(); + + boolean isSpecialToken(int tokenIndex); + + /** + * Determines if a token should be displayed during streaming output. This filters out special tokens, control characters, or other non-displayable content. + * + * @param token + * the token to check + * @return true if the token should be displayed to the user, false otherwise + */ + boolean shouldDisplayToken(int token); + + List encode(String text, Set allowedSpecial); + + List encodeAsList(String text); + + String decode(List tokens); + } diff --git a/src/main/java/org/beehive/gpullama3/tokenizer/vocabulary/Vocabulary.java b/src/main/java/org/beehive/gpullama3/tokenizer/Vocabulary.java similarity index 98% rename from src/main/java/org/beehive/gpullama3/tokenizer/vocabulary/Vocabulary.java rename to src/main/java/org/beehive/gpullama3/tokenizer/Vocabulary.java index 474b4b77..1a867569 100644 --- a/src/main/java/org/beehive/gpullama3/tokenizer/vocabulary/Vocabulary.java +++ b/src/main/java/org/beehive/gpullama3/tokenizer/Vocabulary.java @@ -1,4 +1,4 @@ -package org.beehive.gpullama3.tokenizer.vocabulary; +package org.beehive.gpullama3.tokenizer; import java.util.Arrays; import java.util.Map; @@ -18,15 +18,6 @@ public Vocabulary(String[] vocabulary, float[] scores) { } // @formatter:on - public String get(int tokenIndex) { - return tokens[tokenIndex]; - } - - public OptionalInt getIndex(String token) { - Integer value = tokenToIndex.get(token); - return value != null ? OptionalInt.of(value) : OptionalInt.empty(); - } - public static Vocabulary loadLlamaVocabulary(Map metadata) { String[] tokens = (String[]) metadata.get("tokenizer.ggml.tokens"); return new Vocabulary(tokens, null); @@ -51,6 +42,15 @@ public static Vocabulary loadPhi3Vocabulary(Map metadata) { return new Vocabulary(tokens, scores); } + public String get(int tokenIndex) { + return tokens[tokenIndex]; + } + + public OptionalInt getIndex(String token) { + Integer value = tokenToIndex.get(token); + return value != null ? 
OptionalInt.of(value) : OptionalInt.empty(); + } + public int size() { return tokens.length; } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/GPULLlama3TypeException.java b/src/main/java/org/beehive/gpullama3/tornadovm/GPULLlama3TypeException.java new file mode 100644 index 00000000..78962a2c --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/GPULLlama3TypeException.java @@ -0,0 +1,7 @@ +package org.beehive.gpullama3.tornadovm; + +public class GPULLlama3TypeException extends IllegalArgumentException { + public GPULLlama3TypeException(String message) { + super(message); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/GenericLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/GenericLayerPlanner.java new file mode 100644 index 00000000..5a151212 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/GenericLayerPlanner.java @@ -0,0 +1,14 @@ +package org.beehive.gpullama3.tornadovm; + +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; + +import java.util.List; + +public interface GenericLayerPlanner { + + List getImmutableTaskGraphs(); + + GridScheduler getGridScheduler(); + +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/Phi3TornadoVMLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/Phi3TornadoVMLayerPlanner.java deleted file mode 100644 index 6cfdb821..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/Phi3TornadoVMLayerPlanner.java +++ /dev/null @@ -1,355 +0,0 @@ -package org.beehive.gpullama3.tornadovm; - -import org.beehive.gpullama3.auxiliary.Tuple2; -import org.beehive.gpullama3.inference.state.Phi3State; -import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.model.phi3.Phi3Configuration; -import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; -import uk.ac.manchester.tornado.api.TaskGraph; -import uk.ac.manchester.tornado.api.WorkerGrid; -import uk.ac.manchester.tornado.api.WorkerGrid1D; -import uk.ac.manchester.tornado.api.enums.DataTransferMode; - -import java.util.ArrayList; -import java.util.List; - -public class Phi3TornadoVMLayerPlanner extends TornadoVMLayerPlanner { - - /** - * Constructs a TornadoVMLayerPlanner for the given Llama model. 
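// An editorial sketch of a concrete implementation of the new
// GenericLayerPlanner interface introduced above; the class name and fields
// here are hypothetical (the actual replacement planners live elsewhere in
// this PR).
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import java.util.List;

final class ExampleLayerPlanner implements GenericLayerPlanner {
    private final List<ImmutableTaskGraph> graphs;
    private final GridScheduler scheduler;

    ExampleLayerPlanner(List<ImmutableTaskGraph> graphs, GridScheduler scheduler) {
        this.graphs = graphs;
        this.scheduler = scheduler;
    }

    @Override
    public List<ImmutableTaskGraph> getImmutableTaskGraphs() {
        return graphs;
    }

    @Override
    public GridScheduler getGridScheduler() {
        return scheduler;
    }
}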
- * - * @param state - * The state object containing model tensors and buffers - * @param model - * The Llama model instance containing configuration and weights - */ - public Phi3TornadoVMLayerPlanner(Phi3State state, Model model) { - super(state, model); - } - - public Tuple2, GridScheduler> setupTornadoForwardPlanLayered() { - List taskGraphs = new ArrayList<>(); - - state.temp.init(0.0f); - state.tempFFN.init(0.0f); - state.tempLogits.init(0.0f); - final int opSize = config.dim() + 2 * (config.numberOfKeyValueHeads() * config.headSize()); - - // @formatter:off - TaskGraph activationUpdate = new TaskGraph("activationUpdate") - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.wrapX) - .task("updateX", TransformerComputeKernels::emptyTaskToForceCopyIn, state.wrapX) - .persistOnDevice(state.wrapX); - taskGraphs.add(activationUpdate.snapshot()); - - TaskGraph unifiedLayer = null; - for (int layerIndex = 0; layerIndex < config.numberOfLayers(); layerIndex++) { - unifiedLayer = new TaskGraph("layer_" + layerIndex); - unifiedLayer.consumeFromDevice(state.wrapX); - unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, - weights.rms_att_weightLayered[layerIndex], - weights.wqkvLayered[layerIndex], - weights.woLayered[layerIndex], - weights.rms_ffn_weightLayered[layerIndex], - weights.wDownLayered[layerIndex], - weights.wUpLayered[layerIndex] - ); - unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); - unifiedLayer.task("reductionsOneBlock" , TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, - state.wrapX, config.dim(), config.rmsNormEps(), state.localSize) - .task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, - state.wrapX, weights.rms_att_weightLayered[layerIndex], state.temp) - .task("qkvmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapQkv, - weights.wqkvLayered[layerIndex], config.dim(), opSize, LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("splitQKV", TransformerComputeKernelsLayered::splitQKV, - state.wrapQkv, state.wrapQ, state.wrapK, state.wrapV, - config.dim(), config.headSize() * config.numberOfKeyValueHeads()) - .task("rope", TransformerComputeKernelsLayered::ropeRotationPhi3,context, - state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(), - config.headSize()) - .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, - state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(), layerIndex, config.contextLength()) - .task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttention, context, - state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, - config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), - state.positionHolder, layerIndex, config.contextLength()) - .task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, - state.wrapXb, state.wrapX, weights.woLayered[layerIndex], config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, - state.wrapX, config.dim(), config.rmsNormEps(), state.localSize) - .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, - state.wrapX, weights.rms_ffn_weightLayered[layerIndex], state.tempFFN) - .task("wGateUp", 
TransformerComputeKernelsLayered::matrixVectorGeneric, context, - state.wrapXb, state.wrapHb, weights.wUpLayered[layerIndex], config.dim(), 2 * config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("gateUpSiLU", TransformerComputeKernelsLayered::splitGateUpAndSiLU, - state.wrapHb, state.wrapHbG, state.wrapHbU, config.hiddenDim()) - .task("wDown", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, - state.wrapHbU, state.wrapX, weights.wDownLayered[layerIndex], config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .persistOnDevice( - state.wrapX - ); - taskGraphs.add(unifiedLayer.snapshot()); - } - - TaskGraph lastUnifiedLayer = unifiedLayer; - TaskGraph logits = new TaskGraph("logits") - .consumeFromDevice(lastUnifiedLayer.getTaskGraphName(), - state.wrapX - ) - .transferToDevice(DataTransferMode.EVERY_EXECUTION, - state.tempLogits - ) - .transferToDevice(DataTransferMode.FIRST_EXECUTION, - context, - state.wrapLogits, - weights.wclsHalfFloat, - weights.rms_final_weight_as_floatArray - ) - .task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits, - state.wrapX, config.dim(), config.rmsNormEps(), state.localSize) - .task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, - weights.rms_final_weight_as_floatArray, state.tempLogits); - logits = configureQuantizedMatrixVectorFinalWeight(logits); - logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); - taskGraphs.add(logits.snapshot()); - // @formatter:on - - return new Tuple2<>(taskGraphs, setupGridSchedulersLayered()); - } - - // @formatter:off - /** - * Configures the final projection layer in the task graph based on weight quantization type. - * - * This method adds a "projection" task to compute the final logits by performing a - * matrix-vector multiplication between the model's output embeddings and the classifier - * weights (wcls). The computation kernel used depends on the quantization format. - * - * Supported quantization types: - * - Q8_0: 8-bit quantization with uniform scaling per 32-element block - * - Q4_0: 4-bit quantization with uniform scaling per 32-element block - * - * The task multiplies: - * - weights.wclsByteArray: Quantized classifier weights (vocab_size x dim) - * - state.wrapX: Current layer output (dim) - * - Result: state.wrapLogits: Raw logits (vocab_size) - * - * @param logits The existing task graph to extend with the projection operation - * @return The modified task graph with the projection task added - * @throws UnsupportedOperationException If weights.weightType is not Q8_0 or Q4_0 - */ - // @formatter:on - protected TaskGraph configureQuantizedMatrixVectorFinalWeight(TaskGraph logits) { - switch (weights.getWeightType()) { - case F16: - case Q8_0: - case Q4_0: - logits.task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, // - context, state.wrapX, state.wrapLogits, weights.wclsHalfFloat, // - config.dim(), config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); // - break; - default: - throw new UnsupportedOperationException("Unsupported weight quantization type: " + weights.getWeightType() + ". Only Q8_0 and Q4_0 are supported."); - } - return logits; - } - - /** - * Configures data transfer operations for a specific layer in the neural network task graph. 
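// A minimal sketch of the per-layer transfer pattern described above, assuming
// TornadoVM's TaskGraph API (package paths as in recent TornadoVM releases);
// the class and parameter names are illustrative, not repository API.
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

final class TransferPatternSketch {
    static TaskGraph wire(TaskGraph layerGraph, int layerIndex, FloatArray persistentBuffer) {
        if (layerIndex == 0) {
            // Layer 0: copy the buffer once; it stays resident on the device.
            layerGraph.transferToDevice(DataTransferMode.FIRST_EXECUTION, persistentBuffer);
        } else {
            // Later layers: reuse what the previous layer left on the device.
            layerGraph.consumeFromDevice(persistentBuffer);
        }
        return layerGraph;
    }
}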
- * - * This method manages GPU memory transfers with optimized data movement strategies: This optimization pattern minimizes data movement by: 1. Using one-time transfers for static data 2. Reusing - * intermediate results already on GPU from previous layers 3. Only transferring // dynamic data that changes per execution - * - * @param unifiedLayer - * The task graph representing this layer's operations - * @param layerIndex - * Index of the current layer (0-based) - * @return The configured task graph with appropriate data transfer operations - */ - protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { - // First layer: Transfer initial data to device (one-time transfer) - if (layerIndex == 0) { - // Transfer all attention-related data: query, key, value matrices and their caches - unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder, state.temp, state.tempFFN); // - unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // - context, state.wrapXb, state.wrapXb2, // - state.wrapQ, state.wrapK, state.wrapV, // - state.wrapKeyCache, state.wrapValueCache, // - state.wrapAtt, state.wrapHb, // - state.wrapHbG, state.wrapHbU, state.wrapQkv); // - } else { - // Subsequent layers: Consume data already on device from previous layer - unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2, // - state.wrapQ, state.wrapK, state.wrapV, // - state.wrapKeyCache, state.wrapValueCache, // - state.wrapAtt, state.wrapHb, // - state.positionHolder, // / - state.wrapHbG, state.wrapHbU, state.wrapQkv); - } - return unifiedLayer; - } - - // @formatter:off - /** - * Sets up the grid scheduler configuration for a layered neural network forward pass. - * - * This method creates and configures worker grids for different types of GPU operations - * in the transformer/ML model pipeline. Each worker grid defines how work should be - * distributed across GPU threads (OpenCL work-items or CUDA threads). 
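// An editorial sketch of the WorkerGrid-to-launch-geometry mapping described
// above, for a row-major matvec where each output row gets a 32-thread
// workgroup; the sizes and task name are illustrative.
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;

final class LaunchGeometrySketch {
    static GridScheduler build(int rows) {
        WorkerGrid matvec = new WorkerGrid1D(rows * 32);
        matvec.setLocalWork(32, 1, 1); // OpenCL local_work_size == CUDA blockDim
        // global work size (rows * 32) == OpenCL global_work_size
        //                              == CUDA gridDim * blockDim
        GridScheduler scheduler = new GridScheduler();
        scheduler.addWorkerGrid("layer_0.qmatmul", matvec); // key is "graphName.taskName"
        return scheduler;
    }
}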
- * - * The method creates several worker profiles: - * - Single thread operations (activation updates) - * - RoPE (Rotary Position Embedding) operations - * - Matrix multiplications with different dimensions - * - RMS normalization operations - * - Parallel attention computations - * - Cache copying operations - * - Vocabulary projections - * - * Each worker grid maps to equivalent OpenCL NDRange or CUDA grid/block configurations: - * - setGlobalWork() ≈ OpenCL global_work_size ≈ CUDA grid dimensions × block dimensions - * - setLocalWork() ≈ OpenCL local_work_size ≈ CUDA block dimensions - * - * @return GridScheduler configured with all necessary worker grids for the model layers - */ - // @formatter:on - private GridScheduler setupGridSchedulersLayered() { - GridScheduler tornadoForwardScheduler = new GridScheduler(); - - // Single worker for tasks running with a single thread - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[1,1,1], localWorkSize=[1,1,1]) - // CUDA equivalent: kernel<<>> - WorkerGrid singleWorker = new WorkerGrid1D(1); - singleWorker.setGlobalWork(1, 1, 1); - singleWorker.setLocalWork(1, 1, 1); - - // config.dim / 2 Worker for RoPE - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim/2,1,1], localWorkSize=[128,1,1]) - // CUDA equivalent: kernel<<>> - WorkerGrid ropeWorker = new WorkerGrid1D(config.dim() / 2); - ropeWorker.setGlobalWork(config.dim() / 2, 1, 1); - ropeWorker.setLocalWork(128, 1, 1); - - // config.dim Worker for Row major access - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim*LOCAL_WORK_GROUP_SIZE_ALLOC,1,1], localWorkSize=[LOCAL_WORK_GROUP_SIZE_ALLOC,1,1]) - // CUDA equivalent: kernel<<>> - int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid configDimRowMajorGlobalWorker = new WorkerGrid1D(configDimRowMajorGlobal); - configDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); - - final int opSize = config.dim() + 2 * (config.numberOfKeyValueHeads() * config.headSize()); - - int qkvmatmulDimRowMajorGlobal = opSize * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid qkvDimRowMajorGlobalWorker = new WorkerGrid1D(qkvmatmulDimRowMajorGlobal); - qkvDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); - - // config.kvDim Worker for Row major access - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.kvDim*LOCAL_WORK_GROUP_SIZE_ALLOC,1,1], localWorkSize=[LOCAL_WORK_GROUP_SIZE_ALLOC,1,1]) - // CUDA equivalent: kernel<<>> - int configKvDimRowMajorGlobal = config.kvDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid configKvDimRowMajorGlobalWorker = new WorkerGrid1D(configKvDimRowMajorGlobal); - configKvDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); - - // config.hiddenDim * 32 Worker for Row major access - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.hiddenDim*LOCAL_WORK_GROUP_SIZE_ALLOC,1,1], localWorkSize=[LOCAL_WORK_GROUP_SIZE_ALLOC,1,1]) - // CUDA equivalent: kernel<<>> - int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid configHiddenDimRowMajorWorker = new WorkerGrid1D(configHiddenDimRowMajor); - configHiddenDimRowMajorWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); - - int wgetUPDimRowMajor = 2 * config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid wgetHiddenDimRowMajorWorker = new WorkerGrid1D(wgetUPDimRowMajor); - wgetHiddenDimRowMajorWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); - - // RMSNorm worker 
configuration - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim,1,1], localWorkSize=[256,1,1]) - // CUDA equivalent: kernel<<>> - WorkerGrid rmsNormWorker = new WorkerGrid1D(config.dim()); - rmsNormWorker.setGlobalWork(config.dim(), 1, 1); // Set global work size to total dimension - rmsNormWorker.setLocalWork(256, 1, 1); // Set local work size to 256 (standard efficient size) - - // Parallel attention worker configuration - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.numberOfHeads,1,1], localWorkSize=[4,1,1]) - // CUDA equivalent: kernel<<>> - WorkerGrid parallelAttentionWorker = new WorkerGrid1D(config.numberOfHeads()); - // the global group work size is numberOfHeads * localWorkGroupSize, where the localWorkGroupSize is currently 4 - parallelAttentionWorker.setGlobalWork(config.numberOfHeads() * 8, 1, 1); - parallelAttentionWorker.setLocalWork(8, 1, 1); // Set local work size to 4 (for parallel attention) - - // Copy to caches worker configuration - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim,1,1], localWorkSize=[128,1,1]) - // CUDA equivalent: kernel<<>> - WorkerGrid copyToCachesWorker = new WorkerGrid1D(config.kvDim()); - copyToCachesWorker.setGlobalWork(config.dim(), 1, 1); - copyToCachesWorker.setLocalWork(128, 1, 1); // Set local work size to 32 (for copying to caches) - - // Q copy worker configuration - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim,1,1], localWorkSize=[128,1,1]) - // CUDA equivalent: kernel<<>> - WorkerGrid copyQWorker = new WorkerGrid1D(config.dim()); - copyQWorker.setGlobalWork(config.dim(), 1, 1); - copyQWorker.setLocalWork(128, 1, 1); - - // K copy worker configuration - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[kvSize,1,1], localWorkSize=[128,1,1]) - // CUDA equivalent: kernel<<>> - int kvSize = config.headSize() * config.numberOfKeyValueHeads(); - WorkerGrid copyKWorker = new WorkerGrid1D(kvSize); - copyKWorker.setGlobalWork(kvSize, 1, 1); - copyKWorker.setLocalWork(128, 1, 1); - - // V copy worker configuration - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[kvSize,1,1], localWorkSize=[128,1,1]) - // CUDA equivalent: kernel<<>> - WorkerGrid copyVWorker = new WorkerGrid1D(kvSize); - copyVWorker.setGlobalWork(kvSize, 1, 1); - copyVWorker.setLocalWork(128, 1, 1); - - WorkerGrid hiddenDimWorker = new WorkerGrid1D(config.hiddenDim()); - hiddenDimWorker.setGlobalWork(config.hiddenDim(), 1, 1); - hiddenDimWorker.setLocalWork(128, 1, 1); - - WorkerGrid splitGateUpSiLUWorker = new WorkerGrid1D(config.hiddenDim()); - splitGateUpSiLUWorker.setGlobalWork(config.hiddenDim(), 1, 1); - splitGateUpSiLUWorker.setLocalWork(128, 1, 1); - - // Total work size is dimQ + 2*dimKV (same as opSize) - WorkerGrid splitQKVWorker = new WorkerGrid1D(opSize); - splitQKVWorker.setGlobalWork(opSize, 1, 1); - splitQKVWorker.setLocalWork(128, 1, 1); - - // Map workers to tasks - tornadoForwardScheduler.addWorkerGrid("activationUpdate.updateX", singleWorker); - for (int i = 0; i < config.numberOfLayers(); i++) { - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkvmatmul", qkvDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".splitQKV", splitQKVWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".wDown", 
configDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".wGateUp", wgetHiddenDimRowMajorWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); - // New FFN tasks - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".gateUpSiLU", splitGateUpSiLUWorker); - } - - // Vocabulary worker configuration - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.vocabularySize,1,1], localWorkSize=[16,1,1]) - // CUDA equivalent: kernel<<>> - int vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; - WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); - vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); - - tornadoForwardScheduler.addWorkerGrid("logits.projection", vocabWorker); - tornadoForwardScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", rmsNormWorker); - - return tornadoForwardScheduler; - } -} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/Qwen2TornadoVMLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/Qwen2TornadoVMLayerPlanner.java deleted file mode 100644 index 1f9d547b..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/Qwen2TornadoVMLayerPlanner.java +++ /dev/null @@ -1,250 +0,0 @@ -package org.beehive.gpullama3.tornadovm; - -import org.beehive.gpullama3.auxiliary.Tuple2; -import org.beehive.gpullama3.inference.state.Qwen2State; -import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; -import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; -import uk.ac.manchester.tornado.api.TaskGraph; -import uk.ac.manchester.tornado.api.WorkerGrid; -import uk.ac.manchester.tornado.api.WorkerGrid1D; -import uk.ac.manchester.tornado.api.WorkerGrid2D; -import uk.ac.manchester.tornado.api.enums.DataTransferMode; - -import java.util.ArrayList; -import java.util.List; - -public class Qwen2TornadoVMLayerPlanner extends TornadoVMLayerPlanner { - - /** - * Constructs a TornadoVMLayerPlanner for the given Qwen2 model. 
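// A sketch of how a planner's Tuple2<List<ImmutableTaskGraph>, GridScheduler>
// result is typically consumed, assuming TornadoVM's execution-plan API
// (TornadoExecutionPlan, withGridScheduler); exact signatures may vary by release.
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import java.util.List;

final class ExecutionSketch {
    static void run(List<ImmutableTaskGraph> graphs, GridScheduler scheduler) {
        TornadoExecutionPlan plan =
                new TornadoExecutionPlan(graphs.toArray(new ImmutableTaskGraph[0]));
        // The scheduler supplies the per-task worker grids configured by the planner.
        plan.withGridScheduler(scheduler).execute();
    }
}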
- * - * @param state - * The state object containing model tensors and buffers - * @param model - * The Qwen2 model instance containing configuration and weights - */ - public Qwen2TornadoVMLayerPlanner(Qwen2State state, Model model) { - super(state, model); - } - - @Override - public Tuple2, GridScheduler> setupTornadoForwardPlanLayered() { - List taskGraphs = new ArrayList<>(); - - state.temp.init(0.0f); - state.tempFFN.init(0.0f); - state.tempLogits.init(0.0f); - state.wrapLogits.init(0.0f); - - TaskGraph activationUpdate = new TaskGraph("activationUpdate") - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.wrapX) - .task("updateX", TransformerComputeKernels::emptyTaskToForceCopyIn, state.wrapX) - .persistOnDevice(state.wrapX); - taskGraphs.add(activationUpdate.snapshot()); - - TaskGraph unifiedLayer = null; - for (int layerIndex =0; layerIndex < config.numberOfLayers(); layerIndex++) { - unifiedLayer = new TaskGraph("layer_" + layerIndex); - unifiedLayer.consumeFromDevice(state.wrapX); - unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, - //Copy-in weights per layer for batched-layered layout - weights.rms_att_weightLayered[layerIndex], - weights.wqLayered[layerIndex], - weights.wkLayered[layerIndex], - weights.wvLayered[layerIndex], - weights.woLayered[layerIndex], - weights.q_biasLayered[layerIndex], - weights.k_biasLayered[layerIndex], - weights.v_biasLayered[layerIndex], - weights.rms_ffn_weightLayered[layerIndex], - weights.w1Layered[layerIndex], - weights.w2Layered[layerIndex], - weights.w3Layered[layerIndex] - ); - unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); - - unifiedLayer.task("reductionsOneBlock" , TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, - state.wrapX, config.dim(), config.rmsNormEps(), state.localSize) - .task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, - state.wrapX, weights.rms_att_weightLayered[layerIndex], state.temp) - .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, - state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex], config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, - state.wrapXb, state.wrapK, weights.wkLayered[layerIndex], config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, - state.wrapXb, state.wrapV, weights.wvLayered[layerIndex], config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("qbias", TransformerComputeKernelsLayered::addInPlace, state.wrapQ, weights.q_biasLayered[layerIndex], config.dim()) - .task("kbias", TransformerComputeKernelsLayered::addInPlace, state.wrapK, weights.k_biasLayered[layerIndex], config.kvDim()) - .task("vbias", TransformerComputeKernelsLayered::addInPlace, state.wrapV, weights.v_biasLayered[layerIndex], config.kvDim()) - .task("rope", Qwen3Kernels::ropeRotation,context, state.positionHolder, state.wrapQ, state.wrapK, config.numberOfKeyValueHeads(), - config.headSize()) - .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, - state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(), layerIndex, config.contextLength()) - .task("parallel-attention", Qwen2Kernels::processHeadsFlashAttention, context, - state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, - config.numberOfHeads(), 
config.headSize(), config.kvDim(), config.kvMul(), - state.positionHolder, layerIndex, config.contextLength()) - .task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, - state.wrapXb, state.wrapX, weights.woLayered[layerIndex], config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, - state.wrapX, config.dim(), config.rmsNormEps(), state.localSize) - .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, - state.wrapX, weights.rms_ffn_weightLayered[layerIndex], state.tempFFN) - .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, - state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex], weights.w3Layered[layerIndex], config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, - state.wrapHb, state.wrapX, weights.w2Layered[layerIndex], config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .persistOnDevice( - state.wrapX - ); - taskGraphs.add(unifiedLayer.snapshot()); - } - - TaskGraph lastUnifiedLayer = unifiedLayer; - TaskGraph logits = new TaskGraph("logits") - .consumeFromDevice(lastUnifiedLayer.getTaskGraphName(), - state.wrapX - ) - .transferToDevice(DataTransferMode.EVERY_EXECUTION, - state.tempLogits - ) - .transferToDevice(DataTransferMode.FIRST_EXECUTION, - context, - state.wrapLogits, - weights.wclsHalfFloat, - weights.rms_final_weight_as_floatArray - ) - .task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits, - state.wrapX, config.dim(), config.rmsNormEps(), state.localSize) - .task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, - weights.rms_final_weight_as_floatArray, state.tempLogits); - logits = configureQuantizedMatrixVectorFinalWeight(logits); - logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); - taskGraphs.add(logits.snapshot()); - // @formatter:on - - return new Tuple2<>(taskGraphs, setupQwen2GridSchedulersLayeredNonNvidia()); - } - - @Override - public Tuple2, GridScheduler> setupTornadoForwardPlanLayeredNonNvidia() { - return setupTornadoForwardPlanLayered(); - } - - private GridScheduler setupQwen2GridSchedulersLayeredNonNvidia() { - //throw new UnsupportedOperationException("setupQwen2GridSchedulersLayeredNonNvidia Not supported yet."); - GridScheduler tornadoForwardScheduler = new GridScheduler(); - - // Single worker for tasks running with a single thread - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[1,1,1], localWorkSize=[1,1,1]) - // CUDA equivalent: kernel<<>> - WorkerGrid singleWorker = new WorkerGrid1D(1); - singleWorker.setGlobalWork(1, 1, 1); - singleWorker.setLocalWork(1, 1, 1); - - // config.dim / 2 Worker for RoPE - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim/2,1,1], localWorkSize=[128,1,1]) - // CUDA equivalent: kernel<<>> - int h = config.numberOfHeads(); - int ic = config.headSize() / 2; - WorkerGrid ropeWorker = new WorkerGrid2D(h, ic); - ropeWorker.setGlobalWork(h, ic, 1); - ropeWorker.setLocalWork(1, 1, 1); - - // config.dim Worker for Row major access - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim*LOCAL_WORK_GROUP_SIZE_ALLOC,1,1], 
localWorkSize=[LOCAL_WORK_GROUP_SIZE_ALLOC,1,1]) - // CUDA equivalent: kernel<<>> - int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid configDimRowMajorGlobalWorker = new WorkerGrid1D(configDimRowMajorGlobal); - configDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); - - // config.kvDim Worker for Row major access - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.kvDim*LOCAL_WORK_GROUP_SIZE_ALLOC,1,1], localWorkSize=[LOCAL_WORK_GROUP_SIZE_ALLOC,1,1]) - // CUDA equivalent: kernel<<>> - int configKvDimRowMajorGlobal = config.kvDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid configKvDimRowMajorGlobalWorker = new WorkerGrid1D(configKvDimRowMajorGlobal); - configKvDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); - - WorkerGrid qBiasWorker = new WorkerGrid1D(config.dim()); - qBiasWorker.setGlobalWork(config.dim(), 1, 1); - qBiasWorker.setLocalWork(config.dim() / 8, 1, 1); - WorkerGrid kvBiasWorker = new WorkerGrid1D(config.kvDim()); - kvBiasWorker.setGlobalWork(config.kvDim(), 1, 1); - kvBiasWorker.setLocalWork(32, 1, 1); - - // config.hiddenDim * 32 Worker for Row major access - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.hiddenDim*LOCAL_WORK_GROUP_SIZE_ALLOC,1,1], localWorkSize=[LOCAL_WORK_GROUP_SIZE_ALLOC,1,1]) - // CUDA equivalent: kernel<<>> - int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid configHiddenDimRowMajorWorker = new WorkerGrid1D(configHiddenDimRowMajor); - configHiddenDimRowMajorWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); - - // RMSNorm worker configuration - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim,1,1], localWorkSize=[256,1,1]) - // CUDA equivalent: kernel<<>> - WorkerGrid rmsNormWorker = new WorkerGrid1D(config.dim()); - rmsNormWorker.setGlobalWork(config.dim(), 1, 1); // Set global work size to total dimension - rmsNormWorker.setLocalWork(32, 1, 1); // Set local work size to 256 (standard efficient size) - - // Parallel attention worker configuration - // Calculate optimal local work size based on head dimension - int optimalLocalSize = Math.min(config.headSize(), 64); // Start with 64 threads per head - if (config.headSize() % optimalLocalSize != 0) { - // Find largest divisor of headSize <= 64 - for (int size = 64; size >= 1; size--) { - if (config.headSize() % size == 0) { - optimalLocalSize = size; - break; - } - } - } - - WorkerGrid parallelAttentionWorker = new WorkerGrid1D(config.numberOfHeads()); - parallelAttentionWorker.setGlobalWork(config.numberOfHeads() * optimalLocalSize, 1, 1); - parallelAttentionWorker.setLocalWork(optimalLocalSize, 1, 1); - - // Copy to caches worker configuration - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim,1,1], localWorkSize=[128,1,1]) - // CUDA equivalent: kernel<<>> - WorkerGrid copyToCachesWorker = new WorkerGrid1D(config.kvDim()); - copyToCachesWorker.setGlobalWork(config.kvDim(), 1, 1); - copyToCachesWorker.setLocalWork(32, 1, 1); // Set local work size to 32 (for copying to caches) - - // Map workers to tasks - tornadoForwardScheduler.addWorkerGrid("activationUpdate.updateX", singleWorker); - for (int i = 0; i < config.numberOfLayers(); i++) { - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", configDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", configKvDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + 
".vmatmul", configKvDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qbias", qBiasWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kbias", kvBiasWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vbias", kvBiasWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); - } - - // Vocabulary worker configuration - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.vocabularySize,1,1], localWorkSize=[16,1,1]) - // CUDA equivalent: kernel<<>> - int vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; - WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); - vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); - - tornadoForwardScheduler.addWorkerGrid("logits.projection", vocabWorker); - tornadoForwardScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", rmsNormWorker); - - return tornadoForwardScheduler; - } -} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/Qwen3TornadoVMLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/Qwen3TornadoVMLayerPlanner.java deleted file mode 100644 index 57d08a90..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/Qwen3TornadoVMLayerPlanner.java +++ /dev/null @@ -1,383 +0,0 @@ -package org.beehive.gpullama3.tornadovm; - -import org.beehive.gpullama3.auxiliary.Tuple2; -import org.beehive.gpullama3.inference.state.Qwen3State; -import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; -import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; -import uk.ac.manchester.tornado.api.TaskGraph; -import uk.ac.manchester.tornado.api.WorkerGrid; -import uk.ac.manchester.tornado.api.WorkerGrid1D; -import uk.ac.manchester.tornado.api.WorkerGrid2D; -import uk.ac.manchester.tornado.api.enums.DataTransferMode; - -import java.util.ArrayList; -import java.util.List; - -public class Qwen3TornadoVMLayerPlanner extends TornadoVMLayerPlanner { - - private final int nHeadKv; - private final int nEmbdHeadK; - private final int nEmbdHeadV; - private final int nEmbdVGqa; - private final int nEmbdHead; - private final int nEmbdGqa; - private final int gqa; - - public Qwen3TornadoVMLayerPlanner(Qwen3State state, Model model) { - super(state, model); - - this.nHeadKv = config.numberOfKeyValueHeads(); - this.nEmbdHeadK = config.numberOfHeadsKey(); - 
this.nEmbdHeadV = config.numberOfHeadsValue(); // n_embd_head_v = n_embd / n_head; %s.attention.value_length - this.nEmbdVGqa = nEmbdHeadV * nHeadKv; // n_embd_v_gqa = n_embd_head_v * n_head_kv - this.nEmbdHead = nEmbdHeadV; - this.nEmbdGqa = nEmbdVGqa; - this.gqa = config.numberOfHeads() / config.numberOfKeyValueHeads(); // integer multiplier of the kv sharing in multiquery - } - - // @formatter:off - @Override - protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { - if (layerIndex == 0) { - unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, - state.positionHolder, state.temp, state.tempFFN, - state.tempQcur, state.tempKcur); // - unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // - context, state.wrapXb, state.wrapXb2, // - state.wrapQ, state.wrapK, state.wrapV, // - state.wrapKeyCache, state.wrapValueCache, // - state.wrapAtt, state.wrapHb);// - } else { - // Subsequent layers: Consume data already on device from previous layer - unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2, // - state.wrapQ, state.wrapK, state.wrapV, // - state.wrapKeyCache, state.wrapValueCache, // - state.wrapAtt, state.wrapHb, // - state.positionHolder // - ); - } - return unifiedLayer; - } - // @formatter:on - - // @formatter:off - @Override - public Tuple2, GridScheduler> setupTornadoForwardPlanLayered() { - List taskGraphs = new ArrayList<>(); - - state.temp.init(0.0f); - state.tempFFN.init(0.0f); - state.tempLogits.init(0.0f); - state.wrapLogits.init(0.0f); - - TaskGraph activationUpdate = new TaskGraph("activationUpdate") - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.wrapX) - .task("updateX", TransformerComputeKernels::emptyTaskToForceCopyIn, state.wrapX) - .persistOnDevice(state.wrapX); - taskGraphs.add(activationUpdate.snapshot()); - - TaskGraph unifiedLayer = null; - for (int layerIndex =0; layerIndex < config.numberOfLayers(); layerIndex++) { - unifiedLayer = new TaskGraph("layer_" + layerIndex); - unifiedLayer.consumeFromDevice(state.wrapX); - unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, - //Copy-in weights per layer for batched-layered layout - weights.rms_att_weightLayered[layerIndex], - weights.wqLayered[layerIndex], - weights.wkLayered[layerIndex], - weights.wvLayered[layerIndex], - weights.woLayered[layerIndex], - //rms_att_KNormLayered - weights.rms_att_KNormLayered[layerIndex], - //rms_att_QNormLayered - weights.rms_att_QNormLayered[layerIndex], - weights.rms_ffn_weightLayered[layerIndex], - weights.w1Layered[layerIndex], - weights.w2Layered[layerIndex], - weights.w3Layered[layerIndex] - ); - unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); - unifiedLayer.task("reductionsOneBlock", - TransformerComputeKernelsLayered::reductionOneBlockWithLayer, - context, - state.temp, - state.wrapX, // in - config.dim(), - config.rmsNormEps(), - state.localSize) - .task("mapContext", - TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, - context, - state.wrapXb, // out - state.wrapX, - weights.rms_att_weightLayered[layerIndex], - state.temp); - - int qDim0 = nEmbdHeadK * config.numberOfHeads(); - int kvDim0 = nEmbdGqa; - int qkvDim1 = config.dim(); - unifiedLayer.task("qmatmul", - TransformerComputeKernelsLayered::matrixVectorGeneric, - context, - state.wrapXb, - state.wrapQ, // output - weights.wqLayered[layerIndex], - qkvDim1, - qDim0, - LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("kmatmul", - TransformerComputeKernelsLayered::matrixVectorGeneric, - context, - 
state.wrapXb, - state.wrapK, // output - weights.wkLayered[layerIndex], - qkvDim1, - kvDim0, - LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("vmatmul", - TransformerComputeKernelsLayered::matrixVectorGeneric, - context, - state.wrapXb, - state.wrapV, // output - weights.wvLayered[layerIndex], - qkvDim1, - kvDim0, - LOCAL_WORK_GROUP_SIZE_ALLOC); - - // Qcur rmsnorm - unifiedLayer - .task("rmsnormReduction_Qcur", - Qwen3Kernels::rmsnormWithParallelOffset, - context, - state.tempQcur, // output - state.wrapQ, // input - state.localSize, // currently 128, should be variable of global nEmbHead - nEmbdHead, // for normalization - config.rmsNormEps()) // for normalization - .task("rmsnormMapIndexInPlace_Qcur", - Qwen3Kernels::rmsnormMapIndexInPlaceWithParallelOffset, - context, - state.wrapQ, // output - weights.rms_att_QNormLayered[layerIndex], - nEmbdHead, - state.tempQcur); - - // Kcur rmsnorm - unifiedLayer - .task("rmsnormReduction_Kcur", - Qwen3Kernels::rmsnormWithParallelOffset, - context, - state.tempKcur, // output - state.wrapK, // input - state.localSize, // currently 128, should be variable of global nEmbHead - nEmbdHead, // for normalization - config.rmsNormEps()) // for normalization - .task("rmsnormMapIndexInPlace_Kcur", - Qwen3Kernels::rmsnormMapIndexInPlaceWithParallelOffset, - context, - state.wrapK, // output - weights.rms_att_KNormLayered[layerIndex], - nEmbdHead, - state.tempKcur); - - // rope rotation task graph - unifiedLayer.task("ropeRotation", - Qwen3Kernels::ropeRotation, - context, - state.positionHolder, - state.wrapQ, // out - state.wrapK, // out - config.numberOfKeyValueHeads(), - nEmbdHead); - - unifiedLayer.task("copyToCaches", - TransformerComputeKernelsLayered::copyToCache, - state.wrapKeyCache, // out - state.wrapK, // in - state.wrapValueCache, // out - state.wrapV, // in - state.positionHolder, - nEmbdGqa, - layerIndex, - config.contextLength()); - - unifiedLayer.task("parallel-attention", - TransformerComputeKernelsLayered::processHeadsFlashAttentionOpt, - context, - state.wrapQ, - state.wrapKeyCache, - state.wrapValueCache, - state.wrapXb, // out - config.numberOfHeads(), - nEmbdHead, - nEmbdGqa, - gqa, - state.positionHolder, - layerIndex, - config.contextLength()); - - unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, - context, - state.wrapXb, // vector - state.wrapX, // out, should be [1024] - weights.woLayered[layerIndex], // matrix - nEmbdHeadK * config.numberOfHeads(), // dim1 = 2048 - config.dim(), // dim0 = 1024 - LOCAL_WORK_GROUP_SIZE_ALLOC); - - unifiedLayer.task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, - context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize) - .task("reductionFinalNormalizationFFN" , TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN, - config.dim(), config.rmsNormEps()) - .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, - state.wrapX, weights.rms_ffn_weightLayered[layerIndex], state.tempFFN); - - unifiedLayer.task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, - state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex], weights.w3Layered[layerIndex], config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, - state.wrapHb, state.wrapX, 
weights.w2Layered[layerIndex], config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .persistOnDevice( - state.wrapX - ); - taskGraphs.add(unifiedLayer.snapshot()); - } - - TaskGraph lastUnifiedLayer = unifiedLayer; - TaskGraph logits = new TaskGraph("logits") - .consumeFromDevice(lastUnifiedLayer.getTaskGraphName(), - state.wrapX - ) - .transferToDevice(DataTransferMode.EVERY_EXECUTION, - state.tempLogits, - state.wrapLogits - ) - .transferToDevice(DataTransferMode.FIRST_EXECUTION, - context, - weights.wclsHalfFloat, - weights.rms_final_weight_as_floatArray - ) - .task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, - context, - state.tempLogits, - state.wrapX, - config.dim(), - config.rmsNormEps(), - state.localSize) - .task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, - weights.rms_final_weight_as_floatArray, state.tempLogits); - logits = configureQuantizedMatrixVectorFinalWeight(logits); - logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); - taskGraphs.add(logits.snapshot()); - - return new Tuple2<>(taskGraphs, setupQwen3GridSchedulersLayeredNonNvidia()); - - } - // @formatter:on - - @Override - public Tuple2, GridScheduler> setupTornadoForwardPlanLayeredNonNvidia() { - return setupTornadoForwardPlanLayered(); - } - - private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() { - GridScheduler gridScheduler = new GridScheduler(); - - WorkerGrid singleWorker = new WorkerGrid1D(1); - singleWorker.setGlobalWork(1, 1, 1); - singleWorker.setLocalWork(1, 1, 1); - - WorkerGrid rmsNormWorker = new WorkerGrid1D(config.dim()); - rmsNormWorker.setGlobalWork(config.dim(), 1, 1); // Set global work size to total dimension - rmsNormWorker.setLocalWork(state.localSize, 1, 1); // Set local work size to 256 (standard efficient size) - - int matmulQGlobal = nEmbdHeadK * config.numberOfHeads() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid matmulQRowMajorWorker = new WorkerGrid1D(matmulQGlobal); - matmulQRowMajorWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); - - int matmulKVGlobal = nEmbdGqa * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid matmulKVRowMajorWorker = new WorkerGrid1D(matmulKVGlobal); - matmulKVRowMajorWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); - - WorkerGrid curWorker = new WorkerGrid1D(nEmbdHead); // mEmbdHead = 128 - curWorker.setGlobalWork(nEmbdHead, 1, 1); // Set global work size to total dimension - curWorker.setLocalWork(128, 1, 1); // Set local work size to 256 (standard efficient size) - - // Qcur - WorkerGrid qCurWorker = new WorkerGrid1D(config.numberOfHeads() * nEmbdHead); - qCurWorker.setLocalWork(nEmbdHead, 1, 1); - - // Kcur - WorkerGrid kCurWorker = new WorkerGrid1D(config.numberOfKeyValueHeads() * nEmbdHead); - kCurWorker.setLocalWork(nEmbdHead, 1, 1); - - int h = config.numberOfHeads(); - int ic = nEmbdHead / 2; - WorkerGrid ropeWorker = new WorkerGrid2D(h, ic); - ropeWorker.setGlobalWork(h, ic, 1); - ropeWorker.setLocalWork(8, 1, 1); - - WorkerGrid copyToCachesWorker = new WorkerGrid1D(nEmbdGqa); - copyToCachesWorker.setGlobalWork(nEmbdGqa, 1, 1); - copyToCachesWorker.setLocalWork(128, 1, 1); - - // Parallel attention worker configuration - WorkerGrid parallelAttentionWorker = new WorkerGrid1D(config.numberOfHeads()); - parallelAttentionWorker.setGlobalWork(config.numberOfHeads() * 32, 1, 1); - parallelAttentionWorker.setLocalWork(32, 1, 1); - - int matmul1Global = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid 
matmul1Worker = new WorkerGrid1D(matmul1Global); - matmul1Worker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); - - int fusedFFNW1W3Global = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid fusedFFNW1W3Worker = new WorkerGrid1D(fusedFFNW1W3Global); - fusedFFNW1W3Worker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); - - int projectionTwoGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid projectionTwoWorker = new WorkerGrid1D(projectionTwoGlobal); - projectionTwoWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); - - // Map workers to tasks - gridScheduler.addWorkerGrid("activationUpdate.updateX", singleWorker); - for (int i = 0; i < config.numberOfLayers(); i++) { - gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); - - gridScheduler.addWorkerGrid("layer_" + i + ".qmatmul", matmulQRowMajorWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".kmatmul", matmulKVRowMajorWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".vmatmul", matmulKVRowMajorWorker); - - // Qcur - gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormReduction_Qcur", qCurWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormMapIndexInPlace_Qcur", qCurWorker); - - // Kcur - gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormReduction_Kcur", kCurWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormMapIndexInPlace_Kcur", kCurWorker); - - gridScheduler.addWorkerGrid("layer_" + i + ".ropeRotation", ropeWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".matmul1", matmul1Worker); - gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); - gridScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", fusedFFNW1W3Worker); - gridScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", projectionTwoWorker); - } - - int vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; - WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); - vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); - - gridScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", rmsNormWorker); - gridScheduler.addWorkerGrid("logits.mapContextLogits", rmsNormWorker); - - gridScheduler.addWorkerGrid("logits.projection", vocabWorker); - - return gridScheduler; - } - -} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMLayerPlanner.java deleted file mode 100644 index 54764389..00000000 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMLayerPlanner.java +++ /dev/null @@ -1,543 +0,0 @@ -package org.beehive.gpullama3.tornadovm; - -import org.beehive.gpullama3.auxiliary.Tuple2; -import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; -import org.beehive.gpullama3.model.Configuration; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.inference.state.State; -import uk.ac.manchester.tornado.api.GridScheduler; -import uk.ac.manchester.tornado.api.ImmutableTaskGraph; -import uk.ac.manchester.tornado.api.KernelContext; -import uk.ac.manchester.tornado.api.TaskGraph; -import 
uk.ac.manchester.tornado.api.WorkerGrid; -import uk.ac.manchester.tornado.api.WorkerGrid1D; -import uk.ac.manchester.tornado.api.enums.DataTransferMode; - -import java.util.ArrayList; -import java.util.List; - -// @formatter:off - /** - * TornadoVMLayerPlanner orchestrates the execution planning for transformer model inference - * on GPU using the TornadoVM framework. - * - * This class is responsible for: - * - Creating task graphs for each layer of the neural network - * - Managing GPU memory transfers between layers - * - Configuring worker grids for optimal GPU utilization - * - Setting up the execution schedule for the entire forward pass - * - * The planner implements a layered approach where: - * - Each layer is represented as a separate TaskGraph - * - Data transfers are optimized to minimize host-device communication - * - Worker grids are configured for different types of operations (attention, FFN, etc.) - * - The entire pipeline is scheduled to run efficiently on GPU - * - * Key optimizations include: - * - One-time transfer of static data (weights, caches) - * - Per-execution transfer of dynamic data (position, activations) - * - Device-to-device data consumption between layers - * - Parallelized attention computation across heads - * - * @see TaskGraph - * @see GridScheduler - */ - // @formatter:on - public class TornadoVMLayerPlanner { - protected static final int LOCAL_WORK_GROUP_SIZE_ALLOC = 32; - protected static final int THREAD_SCALE_FOR_LOGITS = 8; - - protected final S state; - protected final C config; - protected final W weights; - protected final KernelContext context; - - /** - * Constructs a TornadoVMLayerPlanner for the given Llama model. - * - * @param state - * The state object containing model tensors and buffers - * @param model - * The Llama model instance containing configuration and weights - */ - public TornadoVMLayerPlanner(S state, Model model) { - this.state = state; - this.config = (C) model.configuration(); - this.weights = (W) model.weights(); - this.context = new KernelContext(); - } - - public Tuple2, GridScheduler> setupTornadoForwardPlanLayered() { - List taskGraphs = new ArrayList<>(); - - state.temp.init(0.0f); - state.tempFFN.init(0.0f); - state.tempLogits.init(0.0f); - - // @formatter:off - TaskGraph activationUpdate = new TaskGraph("activationUpdate") - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.wrapX) - .task("updateX", TransformerComputeKernels::emptyTaskToForceCopyIn, state.wrapX) - .persistOnDevice(state.wrapX); - taskGraphs.add(activationUpdate.snapshot()); - - TaskGraph unifiedLayer = null; - for (int layerIndex =0; layerIndex < config.numberOfLayers(); layerIndex++) { - unifiedLayer = new TaskGraph("layer_" + layerIndex); - unifiedLayer.consumeFromDevice(state.wrapX); - unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, - //Copy-in weights per layer for batched-layered layout - weights.rms_att_weightLayered[layerIndex], - weights.wqLayered[layerIndex], - weights.wkLayered[layerIndex], - weights.wvLayered[layerIndex], - weights.woLayered[layerIndex], - weights.rms_ffn_weightLayered[layerIndex], - weights.w1Layered[layerIndex], - weights.w2Layered[layerIndex], - weights.w3Layered[layerIndex] - ); - unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); - unifiedLayer.task("reductionsOneBlock" , TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, - state.wrapX, config.dim(), config.rmsNormEps(), state.localSize) - .task("mapContext", 
TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, - state.wrapX, weights.rms_att_weightLayered[layerIndex], state.temp) - .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, - state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex], config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, - state.wrapXb, state.wrapK, weights.wkLayered[layerIndex], config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, - state.wrapXb, state.wrapV, weights.wvLayered[layerIndex], config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("rope", TransformerComputeKernelsLayered::ropeRotation,context, - state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(), - config.headSize()) - .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, - state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(), layerIndex, config.contextLength()) - .task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttention, context, - state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, - config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), - state.positionHolder, layerIndex, config.contextLength()) - .task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, - state.wrapXb, state.wrapX, weights.woLayered[layerIndex], config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, - state.wrapX, config.dim(), config.rmsNormEps(), state.localSize) - .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, - state.wrapX, weights.rms_ffn_weightLayered[layerIndex], state.tempFFN) - .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, - state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex], weights.w3Layered[layerIndex], config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, - state.wrapHb, state.wrapX, weights.w2Layered[layerIndex], config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .persistOnDevice( - state.wrapX - ); - taskGraphs.add(unifiedLayer.snapshot()); - } - - TaskGraph lastUnifiedLayer = unifiedLayer; - TaskGraph logits = new TaskGraph("logits") - .consumeFromDevice(lastUnifiedLayer.getTaskGraphName(), - state.wrapX - ) - .transferToDevice(DataTransferMode.EVERY_EXECUTION, - state.tempLogits - ) - .transferToDevice(DataTransferMode.FIRST_EXECUTION, - context, - state.wrapLogits, - weights.wclsHalfFloat, - weights.rms_final_weight_as_floatArray - ) - .task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits, - state.wrapX, config.dim(), config.rmsNormEps(), state.localSize) - .task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, - weights.rms_final_weight_as_floatArray, state.tempLogits); - logits = configureQuantizedMatrixVectorFinalWeight(logits); - logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); - taskGraphs.add(logits.snapshot()); - // @formatter:on - - 
return new Tuple2<>(taskGraphs, setupGridSchedulersLayered()); - } - - // @formatter:off - /** - * Configures the final projection layer in the task graph based on weight quantization type. - * - * This method adds a "projection" task to compute the final logits by performing a - * matrix-vector multiplication between the model's output embeddings and the classifier - * weights (wcls). The computation kernel used depends on the quantization format. - * - * Supported quantization types: - * - Q8_0: 8-bit quantization with uniform scaling per 32-element block - * - Q4_0: 4-bit quantization with uniform scaling per 32-element block - * - * The task multiplies: - * - weights.wclsByteArray: Quantized classifier weights (vocab_size x dim) - * - state.wrapX: Current layer output (dim) - * - Result: state.wrapLogits: Raw logits (vocab_size) - * - * @param logits The existing task graph to extend with the projection operation - * @return The modified task graph with the projection task added - * @throws UnsupportedOperationException If weights.weightType is not Q8_0 or Q4_0 - */ - // @formatter:on - protected TaskGraph configureQuantizedMatrixVectorFinalWeight(TaskGraph logits) { - switch (weights.getWeightType()) { - case F16: - case Q8_0: - case Q4_0: - logits.task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, // - context, state.wrapX, state.wrapLogits, weights.wclsHalfFloat, // - config.dim(), config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); // - break; - default: - throw new UnsupportedOperationException("Unsupported weight quantization type: " + weights.getWeightType() + ". Only Q8_0 and Q4_0 are supported."); - } - return logits; - } - - /** - * Configures data transfer operations for a specific layer in the neural network task graph. - * - * This method manages GPU memory transfers with optimized data movement strategies: - * This optimization pattern minimizes data movement by: - * 1. Using one-time transfers for static data - * 2. Reusing intermediate results already on GPU from previous layers - * 3. Only transferring // - * dynamic data that changes per execution - * - * @param unifiedLayer - * The task graph representing this layer's operations - * @param layerIndex - * Index of the current layer (0-based) - * @return The configured task graph with appropriate data transfer operations - */ - protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { - // First layer: Transfer initial data to device (one-time transfer) - if (layerIndex == 0) { - // Transfer all attention-related data: query, key, value matrices and their caches - unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder, state.temp, state.tempFFN); // - unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // - context, state.wrapXb, state.wrapXb2, // - state.wrapQ, state.wrapK, state.wrapV, // - state.wrapKeyCache, state.wrapValueCache, // - state.wrapAtt, state.wrapHb); // - } else { - // Subsequent layers: Consume data already on device from previous layer - unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2, // - state.wrapQ, state.wrapK, state.wrapV, // - state.wrapKeyCache, state.wrapValueCache, // - state.wrapAtt, state.wrapHb, // - state.positionHolder // - ); - } - return unifiedLayer; - } - - // @formatter:off - /** - * Sets up the grid scheduler configuration for a layered neural network forward pass. 
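// An editorial sketch of the launch-size arithmetic this scheduler uses for
// the logits projection: one workgroup of LOCAL_WORK_GROUP_SIZE_ALLOC *
// THREAD_SCALE_FOR_LOGITS threads per vocabulary row (vocabulary size is
// illustrative).
final class VocabLaunchSketch {
    public static void main(String[] args) {
        int vocabularySize = 32_000;
        int local = 32 * 8;                  // LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS
        int global = vocabularySize * local; // one workgroup per output row
        System.out.println("workgroups = " + (global / local)); // == vocabularySize
    }
}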
- * - * This method creates and configures worker grids for different types of GPU operations - * in the transformer/ML model pipeline. Each worker grid defines how work should be - * distributed across GPU threads (OpenCL work-items or CUDA threads). - * - * The method creates several worker profiles: - * - Single thread operations (activation updates) - * - RoPE (Rotary Position Embedding) operations - * - Matrix multiplications with different dimensions - * - RMS normalization operations - * - Parallel attention computations - * - Cache copying operations - * - Vocabulary projections - * - * Each worker grid maps to equivalent OpenCL NDRange or CUDA grid/block configurations: - * - setGlobalWork() ≈ OpenCL global_work_size ≈ CUDA grid dimensions × block dimensions - * - setLocalWork() ≈ OpenCL local_work_size ≈ CUDA block dimensions - * - * @return GridScheduler configured with all necessary worker grids for the model layers - */ - // @formatter:on - private GridScheduler setupGridSchedulersLayered() { - GridScheduler tornadoForwardScheduler = new GridScheduler(); - - // Single worker for tasks running with a single thread - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[1,1,1], localWorkSize=[1,1,1]) - // CUDA equivalent: kernel<<>> - WorkerGrid singleWorker = new WorkerGrid1D(1); - singleWorker.setGlobalWork(1, 1, 1); - singleWorker.setLocalWork(1, 1, 1); - - // config.dim / 2 Worker for RoPE - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim/2,1,1], localWorkSize=[128,1,1]) - // CUDA equivalent: kernel<<>> - WorkerGrid ropeWorker = new WorkerGrid1D(config.dim() / 2); - ropeWorker.setGlobalWork(config.dim() / 2, 1, 1); - ropeWorker.setLocalWork(128, 1, 1); - - // config.dim Worker for Row major access - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim*LOCAL_WORK_GROUP_SIZE_ALLOC,1,1], localWorkSize=[LOCAL_WORK_GROUP_SIZE_ALLOC,1,1]) - // CUDA equivalent: kernel<<>> - int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid configDimRowMajorGlobalWorker = new WorkerGrid1D(configDimRowMajorGlobal); - configDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); - - // config.kvDim Worker for Row major access - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.kvDim*LOCAL_WORK_GROUP_SIZE_ALLOC,1,1], localWorkSize=[LOCAL_WORK_GROUP_SIZE_ALLOC,1,1]) - // CUDA equivalent: kernel<<>> - int configKvDimRowMajorGlobal = config.kvDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid configKvDimRowMajorGlobalWorker = new WorkerGrid1D(configKvDimRowMajorGlobal); - configKvDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); - - // config.hiddenDim * 32 Worker for Row major access - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.hiddenDim*LOCAL_WORK_GROUP_SIZE_ALLOC,1,1], localWorkSize=[LOCAL_WORK_GROUP_SIZE_ALLOC,1,1]) - // CUDA equivalent: kernel<<>> - int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid configHiddenDimRowMajorWorker = new WorkerGrid1D(configHiddenDimRowMajor); - configHiddenDimRowMajorWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); - - // RMSNorm worker configuration - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim,1,1], localWorkSize=[256,1,1]) - // CUDA equivalent: kernel<<>> - WorkerGrid rmsNormWorker = new WorkerGrid1D(config.dim()); - rmsNormWorker.setGlobalWork(config.dim(), 1, 1); // Set global work size to total dimension - 
rmsNormWorker.setLocalWork(256, 1, 1); // Set local work size to 256 (standard efficient size) - - // Parallel attention worker configuration - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.numberOfHeads,1,1], localWorkSize=[4,1,1]) - // CUDA equivalent: kernel<<>> - WorkerGrid parallelAttentionWorker = new WorkerGrid1D(config.numberOfHeads()); - // the global group work size is numberOfHeads * localWorkGroupSize, where the localWorkGroupSize is currently 4 - parallelAttentionWorker.setGlobalWork(config.numberOfHeads() * 8, 1, 1); - parallelAttentionWorker.setLocalWork(8, 1, 1); // Set local work size to 4 (for parallel attention) - - // Copy to caches worker configuration - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim,1,1], localWorkSize=[128,1,1]) - // CUDA equivalent: kernel<<>> - WorkerGrid copyToCachesWorker = new WorkerGrid1D(config.kvDim()); - copyToCachesWorker.setGlobalWork(config.dim(), 1, 1); - copyToCachesWorker.setLocalWork(128, 1, 1); // Set local work size to 32 (for copying to caches) - - // Map workers to tasks - tornadoForwardScheduler.addWorkerGrid("activationUpdate.updateX", singleWorker); - for (int i = 0; i < config.numberOfLayers(); i++) { - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", configDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", configKvDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", configKvDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); - } - - // Vocabulary worker configuration - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.vocabularySize,1,1], localWorkSize=[16,1,1]) - // CUDA equivalent: kernel<<>> - int vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; - WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); - vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); - - tornadoForwardScheduler.addWorkerGrid("logits.projection", vocabWorker); - tornadoForwardScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", rmsNormWorker); - - return tornadoForwardScheduler; - } - - private GridScheduler setupGridSchedulersLayeredNonNvidia() { - GridScheduler tornadoForwardScheduler = new GridScheduler(); - - // Single worker for tasks running with a single thread - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[1,1,1], localWorkSize=[1,1,1]) - // CUDA equivalent: kernel<<>> - WorkerGrid 
singleWorker = new WorkerGrid1D(1); - singleWorker.setGlobalWork(1, 1, 1); - singleWorker.setLocalWork(1, 1, 1); - - // config.dim / 2 Worker for RoPE - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim/2,1,1], localWorkSize=[128,1,1]) - // CUDA equivalent: kernel<<>> - WorkerGrid ropeWorker = new WorkerGrid1D(config.dim() / 2); - ropeWorker.setGlobalWork(config.dim() / 2, 1, 1); - ropeWorker.setLocalWork(128, 1, 1); - - // config.dim Worker for Row major access - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim*LOCAL_WORK_GROUP_SIZE_ALLOC,1,1], localWorkSize=[LOCAL_WORK_GROUP_SIZE_ALLOC,1,1]) - // CUDA equivalent: kernel<<>> - int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid configDimRowMajorGlobalWorker = new WorkerGrid1D(configDimRowMajorGlobal); - configDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); - - // config.kvDim Worker for Row major access - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.kvDim*LOCAL_WORK_GROUP_SIZE_ALLOC,1,1], localWorkSize=[LOCAL_WORK_GROUP_SIZE_ALLOC,1,1]) - // CUDA equivalent: kernel<<>> - int configKvDimRowMajorGlobal = config.kvDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid configKvDimRowMajorGlobalWorker = new WorkerGrid1D(configKvDimRowMajorGlobal); - configKvDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); - - // config.hiddenDim * 32 Worker for Row major access - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.hiddenDim*LOCAL_WORK_GROUP_SIZE_ALLOC,1,1], localWorkSize=[LOCAL_WORK_GROUP_SIZE_ALLOC,1,1]) - // CUDA equivalent: kernel<<>> - int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; - WorkerGrid configHiddenDimRowMajorWorker = new WorkerGrid1D(configHiddenDimRowMajor); - configHiddenDimRowMajorWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); - - // RMSNorm worker configuration - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim,1,1], localWorkSize=[256,1,1]) - // CUDA equivalent: kernel<<>> - WorkerGrid rmsNormWorker = new WorkerGrid1D(config.dim()); - rmsNormWorker.setGlobalWork(config.dim(), 1, 1); // Set global work size to total dimension - rmsNormWorker.setLocalWork(256, 1, 1); // Set local work size to 256 (standard efficient size) - - // Parallel attention worker configuration - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.numberOfHeads,1,1], localWorkSize=[4,1,1]) - // CUDA equivalent: kernel<<>> - WorkerGrid parallelAttentionWorker = new WorkerGrid1D(config.numberOfHeads()); - // the global group work size is numberOfHeads * localWorkGroupSize, where the localWorkGroupSize is currently 4 - parallelAttentionWorker.setGlobalWork(config.numberOfHeads() * 8, 1, 1); - parallelAttentionWorker.setLocalWork(8, 1, 1); // Set local work size to 4 (for parallel attention) - - // Copy to caches worker configuration - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim,1,1], localWorkSize=[128,1,1]) - // CUDA equivalent: kernel<<>> - WorkerGrid copyToCachesWorker = new WorkerGrid1D(config.kvDim()); - copyToCachesWorker.setGlobalWork(config.dim(), 1, 1); - copyToCachesWorker.setLocalWork(128, 1, 1); // Set local work size to 32 (for copying to caches) - - // Map workers to tasks - tornadoForwardScheduler.addWorkerGrid("activationUpdate.updateX", singleWorker); - for (int i = 0; i < config.numberOfLayers(); i++) { - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", 
configDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", configKvDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", configKvDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); - tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); - } - - // Vocabulary worker configuration - // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.vocabularySize,1,1], localWorkSize=[16,1,1]) - // CUDA equivalent: kernel<<>> - int vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; - WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); - vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); - - tornadoForwardScheduler.addWorkerGrid("logits.projection", vocabWorker); - tornadoForwardScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", rmsNormWorker); - tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", rmsNormWorker); - - return tornadoForwardScheduler; - } - - public Tuple2, GridScheduler> setupTornadoForwardPlanLayeredNonNvidia() { - List taskGraphs = new ArrayList<>(); - - state.temp.init(0.0f); - state.tempFFN.init(0.0f); - state.tempLogits.init(0.0f); - - // @formatter:off - TaskGraph activationUpdate = new TaskGraph("activationUpdate") - .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.wrapX) - .task("updateX", TransformerComputeKernels::emptyTaskToForceCopyIn, state.wrapX) - .persistOnDevice(state.wrapX); - taskGraphs.add(activationUpdate.snapshot()); - - TaskGraph unifiedLayer = null; - for (int layerIndex =0; layerIndex < config.numberOfLayers(); layerIndex++) { - unifiedLayer = new TaskGraph("layer_" + layerIndex); - unifiedLayer.consumeFromDevice(state.wrapX); - unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, - //Copy-in weights per layer for batched-layered layout - weights.rms_att_weightLayered[layerIndex], - weights.wqLayered[layerIndex], - weights.wkLayered[layerIndex], - weights.wvLayered[layerIndex], - weights.woLayered[layerIndex], - weights.rms_ffn_weightLayered[layerIndex], - weights.w1Layered[layerIndex], - weights.w2Layered[layerIndex], - weights.w3Layered[layerIndex] - ); - unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); - unifiedLayer.task("reductionsOneBlock" , TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, - state.wrapX, config.dim(), config.rmsNormEps(), state.localSize) - .task("reductionFinalNormalization" , TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp, - config.dim(), config.rmsNormEps()) - 
.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, - state.wrapX, weights.rms_att_weightLayered[layerIndex], state.temp) - .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, - state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex], config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, - state.wrapXb, state.wrapK, weights.wkLayered[layerIndex], config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, - state.wrapXb, state.wrapV, weights.wvLayered[layerIndex], config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("rope", TransformerComputeKernelsLayered::ropeRotation,context, - state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(), - config.headSize()) - .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, - state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(), layerIndex, config.contextLength()) - .task("parallel-attention", TransformerComputeKernelsLayered::processHeadsParallel, - state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, - config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), config.vocabularySize(), - state.positionHolder, state.wrapAtt, layerIndex, config.contextLength()) - .task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, - state.wrapXb, state.wrapX, weights.woLayered[layerIndex], config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, - state.wrapX, config.dim(), config.rmsNormEps(), state.localSize) - .task("reductionFinalNormalizationFFN" , TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN, - config.dim(), config.rmsNormEps()) - .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, - state.wrapX, weights.rms_ffn_weightLayered[layerIndex], state.tempFFN) - .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, - state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex], weights.w3Layered[layerIndex], config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, - state.wrapHb, state.wrapX, weights.w2Layered[layerIndex], config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) - .persistOnDevice( - state.wrapX - ); - taskGraphs.add(unifiedLayer.snapshot()); - } - - TaskGraph lastUnifiedLayer = unifiedLayer; - TaskGraph logits = new TaskGraph("logits") - .consumeFromDevice(lastUnifiedLayer.getTaskGraphName(), - state.wrapX - ) - .transferToDevice(DataTransferMode.EVERY_EXECUTION, - state.tempLogits - ) - .transferToDevice(DataTransferMode.FIRST_EXECUTION, - context, - state.wrapLogits, - weights.wclsHalfFloat, - weights.rms_final_weight_as_floatArray - ) - .task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits, - state.wrapX, config.dim(), config.rmsNormEps(), state.localSize) - .task("reductionFinalNormalizationLogits" , TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempLogits, - config.dim(), 
config.rmsNormEps()) - .task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, - weights.rms_final_weight_as_floatArray, state.tempLogits); - logits = configureQuantizedMatrixVectorFinalWeight(logits); - logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); - taskGraphs.add(logits.snapshot()); - // @formatter:on - - return new Tuple2<>(taskGraphs, setupGridSchedulersLayeredNonNvidia()); - } - - } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java index 58f725af..293d2c0c 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java @@ -1,42 +1,27 @@ package org.beehive.gpullama3.tornadovm; -import org.beehive.gpullama3.auxiliary.Tuple2; -import org.beehive.gpullama3.inference.state.Phi3State; -import org.beehive.gpullama3.inference.state.Qwen2State; -import org.beehive.gpullama3.inference.state.Qwen3State; +import org.beehive.gpullama3.tensor.GGMLType; import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.model.ModelType; -import uk.ac.manchester.tornado.api.GridScheduler; +import org.beehive.gpullama3.tornadovm.layerplanner.base.QuantizationPlannerFactory; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TornadoExecutionPlan; -import uk.ac.manchester.tornado.api.TornadoRuntime; -import uk.ac.manchester.tornado.api.runtime.TornadoRuntimeProvider; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; -import java.util.List; -import java.util.Locale; - public class TornadoVMMasterPlan { public static final boolean ENABLE_TORNADOVM_INIT_TIME = Boolean.parseBoolean(System.getProperty("llama.EnableTimingForTornadoVMInit", "False")); private final State state; private final Configuration config; - public GridScheduler scheduler; public TornadoExecutionPlan executionPlan; - List taskGraphs; + GenericLayerPlanner tornadoVMLayerPlanner; public TornadoVMMasterPlan(State state, Model model) { - TornadoVMLayerPlanner tornadoVMLayerPlanner = createPlanner(state, model); - Tuple2, GridScheduler> tornadoVMPlan = shouldUseNvidiaScheduler(model) - ? tornadoVMLayerPlanner.setupTornadoForwardPlanLayered() - : tornadoVMLayerPlanner.setupTornadoForwardPlanLayeredNonNvidia(); - this.taskGraphs = tornadoVMPlan.getFirst(); - this.scheduler = tornadoVMPlan.getSecond(); + this.tornadoVMLayerPlanner = createPlanner(state, model); + this.executionPlan = createExecutionPlan(); this.state = state; this.config = model.configuration(); - this.executionPlan = new TornadoExecutionPlan(taskGraphs.toArray(new ImmutableTaskGraph[taskGraphs.size()])); } /** @@ -93,17 +78,21 @@ public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model mod return tornadoVMPlan; } - /** - * Dispatcher method to select the TornadoVMLayerPlanner for the model. 
- */ - TornadoVMLayerPlanner createPlanner(State state, Model model) { - return switch (model.getModelType()) { - case LLAMA_3, MISTRAL -> new TornadoVMLayerPlanner(state, model); - case PHI_3 -> new Phi3TornadoVMLayerPlanner((Phi3State) state, model); - case QWEN_2, DEEPSEEK_R1_DISTILL_QWEN -> new Qwen2TornadoVMLayerPlanner((Qwen2State) state, model); - case QWEN_3 -> new Qwen3TornadoVMLayerPlanner((Qwen3State) state, model); - case UNKNOWN -> throw new UnsupportedOperationException("Unknown model type"); - }; + private TornadoExecutionPlan createExecutionPlan() { + var taskGraphs = tornadoVMLayerPlanner.getImmutableTaskGraphs(); + var taskGraphArray = taskGraphs.toArray(new ImmutableTaskGraph[taskGraphs.size()]); + return new TornadoExecutionPlan(taskGraphArray); + } + + private GenericLayerPlanner createPlanner(State state, Model model) { + // ========== STEP 1: Detect Quantization Type ========== + GGMLType weightType = model.weights().getWeightType(); + + // ========== STEP 2: Route via Factory ========== + // Factory handles all model × quantization combinations + GenericLayerPlanner basePlanner = QuantizationPlannerFactory.create(weightType, state, model); + + return basePlanner; } /** @@ -117,21 +106,9 @@ TornadoVMLayerPlanner createPlanner(State state, Model model) { * the model whose type may affect the scheduler decision * @return {@code true} if the NVIDIA-specific scheduler should be used; {@code false} otherwise */ - public static boolean shouldUseNvidiaScheduler(Model model) { - TornadoRuntime runtime = TornadoRuntimeProvider.getTornadoRuntime(); - String platformName = runtime.getBackend(0).getDefaultDevice().getPlatformName().toLowerCase(Locale.ROOT); - - boolean isNvidia = platformName.contains("nvidia"); - boolean isNotMistral = model.getModelType() != ModelType.MISTRAL; - - boolean result = isNvidia && isNotMistral; - - return result; - } /** - * Executes the forward pass of a LLaMA transformer model using TornadoVM acceleration. - *This method processes the transformer layers in sequence for a particular token position in the context + * Executes the forward pass of a LLaMA transformer model using TornadoVM acceleration. This method processes the transformer layers in sequence for a particular token position in the context * window. * *

The execution happens in three phases: @@ -146,11 +123,12 @@ public static boolean shouldUseNvidiaScheduler(Model model) { * @return FloatTensor containing the output logits for token prediction */ + // int pos, ModelPlanner public FloatArray tornadoVMForwardExecuteLayered(int position) { // @formatter:off // 1. Execute the preprocessing graph (e.g., input preparation, memory initialization) executionPlan.withGraph(getPreprocessingGraphIndex()) - .withGridScheduler(scheduler) + .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) .execute(); // Set the position in the state object (used by attention layers) @@ -160,13 +138,13 @@ public FloatArray tornadoVMForwardExecuteLayered(int position) { // Each graph computes attention and feed-forward transformations for one layer for (int layer = 0; layer < config.numberOfLayers(); layer++) { executionPlan.withGraph(getLayerGraphIndex(layer)) - .withGridScheduler(scheduler) + .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) .execute(); } // 3. Execute the final graph that projects the last hidden state to output logits executionPlan.withGraph(getFinalLogitsGraphIndex()) - .withGridScheduler(scheduler) + .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()) .execute(); // @formatter:on @@ -195,7 +173,7 @@ private int getLayerGraphIndex(int layerIndex) { * Returns the graph index for the final projection to logits. */ private int getFinalLogitsGraphIndex() { - return taskGraphs.size() - 1; + return tornadoVMLayerPlanner.getImmutableTaskGraphs().size() - 1; } /// Execute the forward pass of the LLaMA transformer model using TornadoVM acceleration just once to copy the data into the read-only data layer. @@ -205,15 +183,15 @@ public void forceCopyInReadOnlyDataLayered() { state.positionHolder.init(0); // Execute activation update graph - executionPlan.withGraph(0).withGridScheduler(scheduler).execute(); + executionPlan.withGraph(0).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); // Execute layer processing graphs for (int layer = 0; layer < config.numberOfLayers(); layer++) { - executionPlan.withGraph(layer + 1).withGridScheduler(scheduler).execute(); + executionPlan.withGraph(layer + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); } // Execute logits graph - executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(scheduler).execute(); + executionPlan.withGraph(config.numberOfLayers() + 1).withGridScheduler(tornadoVMLayerPlanner.getGridScheduler()).execute(); } /** diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/Qwen2Kernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen2Kernels.java similarity index 99% rename from src/main/java/org/beehive/gpullama3/tornadovm/Qwen2Kernels.java rename to src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen2Kernels.java index 2b69d296..455be76a 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/Qwen2Kernels.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen2Kernels.java @@ -1,4 +1,4 @@ -package org.beehive.gpullama3.tornadovm; +package org.beehive.gpullama3.tornadovm.kernels; import uk.ac.manchester.tornado.api.KernelContext; import uk.ac.manchester.tornado.api.math.TornadoMath; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/Qwen3Kernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java similarity index 99% rename from src/main/java/org/beehive/gpullama3/tornadovm/Qwen3Kernels.java rename to 
src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java index f09696c4..930e1774 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/Qwen3Kernels.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java @@ -1,4 +1,4 @@ -package org.beehive.gpullama3.tornadovm; +package org.beehive.gpullama3.tornadovm.kernels; import uk.ac.manchester.tornado.api.KernelContext; import uk.ac.manchester.tornado.api.annotations.Parallel; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TransformerComputeKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java similarity index 98% rename from src/main/java/org/beehive/gpullama3/tornadovm/TransformerComputeKernels.java rename to src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java index 7b4f6112..7f69e496 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TransformerComputeKernels.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java @@ -1,4 +1,4 @@ -package org.beehive.gpullama3.tornadovm; +package org.beehive.gpullama3.tornadovm.kernels; import uk.ac.manchester.tornado.api.KernelContext; import uk.ac.manchester.tornado.api.math.TornadoMath; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TransformerComputeKernelsLayered.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java similarity index 87% rename from src/main/java/org/beehive/gpullama3/tornadovm/TransformerComputeKernelsLayered.java rename to src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java index eedae53c..dfe4ef27 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/TransformerComputeKernelsLayered.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java @@ -1,10 +1,11 @@ -package org.beehive.gpullama3.tornadovm; +package org.beehive.gpullama3.tornadovm.kernels; import uk.ac.manchester.tornado.api.KernelContext; import uk.ac.manchester.tornado.api.annotations.Parallel; import uk.ac.manchester.tornado.api.math.TornadoMath; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import uk.ac.manchester.tornado.api.types.arrays.Int8Array; import uk.ac.manchester.tornado.api.types.arrays.IntArray; public class TransformerComputeKernelsLayered { @@ -137,13 +138,6 @@ public static void copyToCache(FloatArray destKeyCache, FloatArray srcKey, Float } } - public static void copyTo(FloatArray src, int srcOffset, FloatArray dest, int destOffset, int size) { - // Generic copy: src[srcOffset:srcOffset+size] -> dest[destOffset:destOffset+size] - for (@Parallel int i = 0; i < size; i++) { - dest.set(destOffset + i, src.get(srcOffset + i)); - } - } - public static void splitQKV(FloatArray qkv, FloatArray q, FloatArray k, FloatArray v, int dimQ, int dimKV) { int totalSize = dimQ + 2 * dimKV; @@ -253,51 +247,6 @@ public static void ropeRotationPhi3(KernelContext context, IntArray positionHold } } - /** - * Orchestrates parallel multi-head attention computation across all heads. Each head processes attention independently in parallel. - * - * Attention computation: 1. Compute attention scores (Q·K) 2. Apply softmax for attention weights 3. 
Compute weighted sum of values (attention·V) - * - * @param q - * Query vectors for all heads - * @param key_cache - * Cached key vectors - * @param value_cache - * Cached value vectors - * @param xb - * Output buffer for attention results - * @param nHeads - * Number of attention heads - * @param headSize - * Dimension of each head - * @param kvDim - * Total key/value dimension - * @param kvMul - * Key/value head multiplier for grouped-query attention - * @param seqLen - * Current sequence length - * @param positionHolder - * Array containing position and layer info - * @param wrapAtt - * Buffer for attention weights - * @param layer - * Current transformer layer - * @param contextLength - * Maximum context length - */ - public static void processHeadsParallel(FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, int seqLen, - IntArray positionHolder, FloatArray wrapAtt, int layer, int contextLength) { - - int pos = positionHolder.get(0); - int loff = layer * contextLength * kvDim; - - // Parallelize computation across attention heads - for (@Parallel int h = 0; h < nHeads; h++) { - // Process each head in parallel - processHeadTornado(q, key_cache, value_cache, xb, h, headSize, kvDim, kvMul, loff, pos, wrapAtt); - } - } - /** * Computes attention for a single head. Implements scaled dot-product attention with softmax normalization. * @@ -971,4 +920,185 @@ public static void addInPlace(FloatArray arrayA, FloatArray arrayB, int size) { } } + /** + * Matrix-vector multiplication for Q8_0 quantized weights. + * + * @param context + * Kernel context + * @param x + * Input activations (FloatArray) + * @param output + * Output array (FloatArray) + * @param weightsQ + * Quantized weights (Int8Array) - from Q8_0QuantizedTensor.getQuants() + * @param weightScales + * Scale factors (HalfFloatArray) - from Q8_0QuantizedTensor.getScales() + * @param dim1 + * Input dimension (n - number of columns) + * @param dim0 + * Output dimension (d - number of rows) + * @param localWorkGroupSize + * Local workgroup size + */ + public static void matrixVectorGeneric(KernelContext context, FloatArray x, FloatArray output, Int8Array weightsQ, HalfFloatArray weightScales, int dim1, int dim0, int localWorkGroupSize) { + + // One row per workgroup + int rowId = context.groupIdx; + int localId = context.localIdx; + + // Early exit if this workgroup is beyond output dimension + if (rowId >= dim0) { + return; + } + + float sum = matrixVectorRowMajorOptimizedQ8_0(context, localWorkGroupSize, x, weightsQ, weightScales, dim1); + + // Thread 0 writes the result + if (localId == 0) { + output.set(rowId, sum); + } + } + + /** + * Helper method to compute dot product for a single row with Q8_0 quantized weights. Uses 4-way unrolling for better performance. 
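For reference, once the 4-way unrolling and the workgroup reduction are stripped away, the per-row arithmetic of the helper below is a block-scaled dot product. A scalar sketch with plain arrays standing in for Int8Array/HalfFloatArray (illustrative only, not code from this patch):

```java
// Illustrative scalar reference: the block-scaled dot product that
// matrixVectorRowMajorOptimizedQ8_0 computes for one output row.
final class Q8_0DotSketch {
    static float dotRow(byte[] quants, float[] scales, float[] x, int row, int n) {
        final int blockSize = 32;                     // Q8_0 block size
        int rowOffset = row * n;                      // this row's quantized weights
        int scalesRowOffset = row * (n / blockSize);  // this row's per-block scales
        float sum = 0f;
        for (int j = 0; j < n; j++) {
            float scale = scales[scalesRowOffset + j / blockSize];
            sum += (quants[rowOffset + j] * scale) * x[j]; // dequantize, multiply, accumulate
        }
        return sum;
    }
}
```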
+ */ + public static float matrixVectorRowMajorOptimizedQ8_0(KernelContext context, int localSize, FloatArray x, Int8Array weightsQ, HalfFloatArray weightScales, int n) { + int rowId = context.groupIdx; + int localId = context.localIdx; + int blockSize = 32; + + // Allocate local memory for reduction + float[] localSums = context.allocateFloatLocalArray(localSize); + + int rowOffset = rowId * n; + int scalesRowOffset = rowId * (n / blockSize); + + // 4-way unrolling + float partialSum1 = 0.0f; + float partialSum2 = 0.0f; + float partialSum3 = 0.0f; + float partialSum4 = 0.0f; + + // Main loop - process 4 elements at a time + for (int j = localId * 4; j < n - 3; j += localSize * 4) { + int blockIdx = j / blockSize; + float scale = weightScales.get(scalesRowOffset + blockIdx).getFloat32(); + + // Dequantize and multiply + partialSum1 += ((float) weightsQ.get(rowOffset + j) * scale) * x.get(j); + partialSum2 += ((float) weightsQ.get(rowOffset + j + 1) * scale) * x.get(j + 1); + partialSum3 += ((float) weightsQ.get(rowOffset + j + 2) * scale) * x.get(j + 2); + partialSum4 += ((float) weightsQ.get(rowOffset + j + 3) * scale) * x.get(j + 3); + } + + float partialSum = partialSum1 + partialSum2 + partialSum3 + partialSum4; + + // Handle remaining elements + for (int j = ((n / 4) * 4) + localId; j < n; j += localSize) { + int blockIdx = j / blockSize; + float scale = weightScales.get(scalesRowOffset + blockIdx).getFloat32(); + partialSum += ((float) weightsQ.get(rowOffset + j) * scale) * x.get(j); + } + + // Store partial sum + localSums[localId] = partialSum; + context.localBarrier(); + + // Parallel reduction + for (int stride = localSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSums[localId] += localSums[localId + stride]; + } + context.localBarrier(); + } + + return localSums[0]; + } + + public static void matrixVectorGenericWithResidual(KernelContext context, FloatArray x, FloatArray hb, Int8Array w_quants, HalfFloatArray w_scales, int n, int d, int localWorkGroupSize) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + int localSize = localWorkGroupSize; + + // Early exit if this workgroup is beyond our output dimension + if (rowId >= d) { + return; + } + + float sum = matrixVectorRowMajorOptimizedQ8_0(context, localSize, x, w_quants, w_scales, n); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + float result = hb.get(rowId) + sum; + hb.set(rowId, result); + } + } + + public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext context, FloatArray x, FloatArray hb, Int8Array w1_quants, HalfFloatArray w1_scales, Int8Array w3_quants, + HalfFloatArray w3_scales, int n, int d, int localWorkGroupSize) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + + if (rowId >= d) { + return; + } + + float sum1 = matrixVectorRowMajorOptimizedQ8_0(context, localWorkGroupSize, x, w1_quants, w1_scales, n); + float sum3 = matrixVectorRowMajorOptimizedQ8_0(context, localWorkGroupSize, x, w3_quants, w3_scales, n); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + float silu = siluActivation(sum1); // Using the new SiLU method + float result = silu * sum3; + hb.set(rowId, result); + } + } + + /** + * Orchestrates parallel multi-head attention computation across all heads. Each head processes attention independently in parallel. + * + * Attention computation: 1. 
Compute attention scores (Q·K) 2. Apply softmax for attention weights 3. Compute weighted sum of values (attention·V) + * + * @param q + * Query vectors for all heads + * @param key_cache + * Cached key vectors + * @param value_cache + * Cached value vectors + * @param xb + * Output buffer for attention results + * @param nHeads + * Number of attention heads + * @param headSize + * Dimension of each head + * @param kvDim + * Total key/value dimension + * @param kvMul + * Key/value head multiplier for grouped-query attention + * @param seqLen + * Current sequence length + * @param positionHolder + * Array containing position and layer info + * @param wrapAtt + * Buffer for attention weights + * @param layer + * Current transformer layer + * @param contextLength + * Maximum context length + */ + public static void processHeadsParallel(FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, int seqLen, + IntArray positionHolder, FloatArray wrapAtt, int layer, int contextLength) { + + int pos = positionHolder.get(0); + int loff = layer * contextLength * kvDim; + + // Parallelize computation across attention heads + for (@Parallel int h = 0; h < nHeads; h++) { + // Process each head in parallel + processHeadTornado(q, key_cache, value_cache, xb, h, headSize, kvDim, kvMul, loff, pos, wrapAtt); + } + } + } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/WorkerGridFactory.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/WorkerGridFactory.java new file mode 100644 index 00000000..af39c133 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/WorkerGridFactory.java @@ -0,0 +1,102 @@ +package org.beehive.gpullama3.tornadovm.layerplanner; + +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.WorkerGrid1D; +import uk.ac.manchester.tornado.api.WorkerGrid2D; + +public class WorkerGridFactory { + private static final int DEFAULT_WORK_GROUP_SIZE = 32; + + /** + * RMS Norm worker: parallel reduction across dimension // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim,1,1], localWorkSize=[256,1,1]) // CUDA equivalent: + * kernel<<>> + */ + public static WorkerGrid createRmsNormWorker(int dim, int localSize) { + WorkerGrid worker = new WorkerGrid1D(dim); + worker.setGlobalWork(dim, 1, 1); + worker.setLocalWork(localSize, 1, 1); + return worker; + } + + // Single worker for tasks running with a single thread + // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[1,1,1], localWorkSize=[1,1,1]) + // CUDA equivalent: kernel<<>> + public static WorkerGrid createSingleWorker() { + WorkerGrid worker = new WorkerGrid1D(1); + worker.setGlobalWork(1, 1, 1); + worker.setLocalWork(1, 1, 1); + return worker; + } + + /** + * QKV matmul worker: combined projection output + */ + public static WorkerGrid createQkvMatmulWorker(int opSize) { + int global = opSize * DEFAULT_WORK_GROUP_SIZE; + WorkerGrid worker = new WorkerGrid1D(global); + worker.setLocalWork(DEFAULT_WORK_GROUP_SIZE, 1, 1); + return worker; + } + + public static WorkerGrid genericWorker(int globalWorkSize, int localWorkSize) { + WorkerGrid worker = new WorkerGrid1D(globalWorkSize); + worker.setLocalWork(localWorkSize, 1, 1); + return worker; + } + + /** + * RoPE worker: 2D grid for position encoding + */ + public static WorkerGrid createRoPEWorker(int numberOfHeads, int headSize) { + int ic = headSize / 2; + WorkerGrid worker = new WorkerGrid2D(numberOfHeads, ic); + 
worker.setGlobalWork(numberOfHeads, ic, 1); + worker.setLocalWork(8, 1, 1); + return worker; + } + + /** + * Attention worker: compute all heads in parallel + */ + public static WorkerGrid createAttentionWorker(int numberOfHeads, int headSize) { + int optimalLocalSize = findOptimalLocalSize(headSize); + WorkerGrid worker = new WorkerGrid1D(numberOfHeads); + worker.setGlobalWork(numberOfHeads * optimalLocalSize, 1, 1); + worker.setLocalWork(optimalLocalSize, 1, 1); + return worker; + } + + /** + * FFN gate+up worker: combined projection + */ + public static WorkerGrid createGateUpWorker(int hiddenDim) { + int global = (2 * hiddenDim) * DEFAULT_WORK_GROUP_SIZE; + WorkerGrid worker = new WorkerGrid1D(global); + worker.setLocalWork(DEFAULT_WORK_GROUP_SIZE, 1, 1); + return worker; + } + + /** + * FFN down worker: final projection + */ + public static WorkerGrid createDownWorker(int dim) { + int global = dim * DEFAULT_WORK_GROUP_SIZE; + WorkerGrid worker = new WorkerGrid1D(global); + worker.setLocalWork(DEFAULT_WORK_GROUP_SIZE, 1, 1); + return worker; + } + + private static int findOptimalLocalSize(int size) { + int optimal = Math.min(size, 64); + if (size % optimal != 0) { + for (int s = 64; s >= 1; s--) { + if (size % s == 0) { + optimal = s; + break; + } + } + } + return optimal; + } + +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java new file mode 100644 index 00000000..1684a5b8 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java @@ -0,0 +1,84 @@ +package org.beehive.gpullama3.tornadovm.layerplanner.base; + +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.state.Phi3State; +import org.beehive.gpullama3.inference.state.Qwen2State; +import org.beehive.gpullama3.inference.state.Qwen3State; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tornadovm.GenericLayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.LlamaFP16LayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Phi3FP16LayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen2FP16LayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen3FP16LayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.LlamaQ8_0LayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Phi3Q8_0LayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen2Q8_0LayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen3Q8_0LayerPlanner; + +/** + * Factory class responsible for creating appropriate layer planners based on model type and quantization. + *

+ * <p>
+ * The factory follows a routing logic:
+ * <ol>
+ *   <li>Determine quantization type from {@link GGMLType}</li>
+ *   <li>Determine model type from {@link Model}</li>
+ *   <li>Instantiate appropriate planner implementation</li>
+ * </ol>
+ * <p>
+ * Examples:
+ * <ul>
+ *   <li>{@code QuantizationType.FP16 + ModelType.LLAMA_3 → LlamaFP16LayerPlanner}</li>
+ *   <li>{@code QuantizationType.Q8_0 + ModelType.QWEN_2 → Qwen2Q8_0LayerPlanner}</li>
+ * </ul>
+ */ +public class QuantizationPlannerFactory { + + /** + * Main factory method: create planner for given model + quantization + */ + public static GenericLayerPlanner create(GGMLType quantization, State state, Model model) { + return switch (quantization) { + case F32 -> createFP32Planner(state, model); + case F16 -> createFP16Planner(state, model); + case Q8_0 -> createQ8_0Planner(state, model); + case Q4_0 -> createQ4_0Planner(state, model); + default -> throw new UnsupportedOperationException("Quantization not supported: " + quantization); + }; + } + + // ============ FP16 Planners ============ + private static GenericLayerPlanner createFP16Planner(State state, Model model) { + return switch (model.getModelType()) { + case LLAMA_3, MISTRAL -> new LlamaFP16LayerPlanner((LlamaState) state, model); + case QWEN_2 -> new Qwen2FP16LayerPlanner((Qwen2State) state, model); + case QWEN_3 -> new Qwen3FP16LayerPlanner((Qwen3State) state, model); + case PHI_3 -> new Phi3FP16LayerPlanner((Phi3State) state, model); + case DEEPSEEK_R1_DISTILL_QWEN -> new Qwen2FP16LayerPlanner((Qwen2State) state, model); + default -> throw new UnsupportedOperationException("FP16 not supported for model: " + model.getModelType()); + }; + } + + // ============ Q8_0 Planners ============ + private static GenericLayerPlanner createQ8_0Planner(State state, Model model) { + return switch (model.getModelType()) { + case LLAMA_3, MISTRAL -> new LlamaQ8_0LayerPlanner((LlamaState) state, model); + case QWEN_2 -> new Qwen2Q8_0LayerPlanner((Qwen2State) state, model); + case QWEN_3 -> new Qwen3Q8_0LayerPlanner((Qwen3State) state, model); + case PHI_3 -> new Phi3Q8_0LayerPlanner((Phi3State) state, model); + case DEEPSEEK_R1_DISTILL_QWEN -> new Qwen2Q8_0LayerPlanner((Qwen2State) state, model); + default -> throw new UnsupportedOperationException("Q8_0 not supported for model: " + model.getModelType()); + }; + } + + // ============ FP32 Planners (FUTURE) ============ + private static GenericLayerPlanner createFP32Planner(State state, Model model) { + throw new UnsupportedOperationException("FP32 planners not yet implemented"); + } + + private static GenericLayerPlanner createQ4_0Planner(State state, Model model) { + throw new UnsupportedOperationException("Q4 planners not yet implemented"); + } + +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizedLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizedLayerPlanner.java new file mode 100644 index 00000000..f95d5406 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizedLayerPlanner.java @@ -0,0 +1,65 @@ +package org.beehive.gpullama3.tornadovm.layerplanner.base; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.tornadovm.GenericLayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import uk.ac.manchester.tornado.api.KernelContext; + +/** + * Abstract base for all quantization-specific planners. + * + * Contains shared logic that works regardless of model type but depends on quantization. Subclasses: FP16LayerPlanner, Q8_0LayerPlanner, etc. 
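For reference, the check each subclass performs in validateQuantizationType() could look like the following; FP16Weights is the type named in the comment above, but this exact body and message are an assumption, not code from this patch:

```java
// Hypothetical body for an FP16 planner's validateQuantizationType():
// fail fast if the factory routed mismatched weights to this planner.
@Override
protected void validateQuantizationType() {
    if (!(weights instanceof FP16Weights)) {
        throw new IllegalStateException(
                "FP16 planner requires FP16Weights, got: " + weights.getClass().getSimpleName());
    }
}
```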
+ */
+public abstract class QuantizedLayerPlanner<S extends State, C extends Configuration, W extends Weights> implements GenericLayerPlanner {
+
+    // Common state for all quantizations
+    protected static final int LOCAL_WORK_GROUP_SIZE_ALLOC = 32;
+    protected static final int THREAD_SCALE_FOR_LOGITS = 8;
+
+    protected final S state;
+    protected final C config;
+    protected final W weights;
+    protected final KernelContext context;
+    protected final Model model;
+    protected final SchedulerType schedulerType;
+
+    /**
+     * Constructor: validate quantization type, extract model components
+     */
+    protected QuantizedLayerPlanner(S state, Model model) {
+        this.state = state;
+        this.model = model;
+        this.config = (C) model.configuration();
+        this.weights = (W) model.weights();
+        this.context = new KernelContext();
+        this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
+        validateQuantizationType();
+    }
+
+    /**
+     * Override in subclasses to validate correct quantization format. E.g., FP16LayerPlanner checks: weights instanceof FP16Weights
+     */
+    protected abstract void validateQuantizationType();
+
+    /**
+     * Override in subclasses for model-specific initialization
+     */
+    protected abstract void initializeLayerComponents();
+
+    // Common helper methods for all quantizations
+    protected C getConfig() {
+        return config;
+    }
+
+    protected W getWeights() {
+        return weights;
+    }
+
+    protected S getState() {
+        return state;
+    }
+}
\ No newline at end of file
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java
new file mode 100644
index 00000000..0480d513
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java
@@ -0,0 +1,27 @@
+package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16;
+
+import org.beehive.gpullama3.inference.state.LlamaState;
+import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.llama.LlamaConfiguration;
+import org.beehive.gpullama3.tornadovm.layerplanner.quantization.FP16LayerPlanner;
+import org.beehive.gpullama3.tornadovm.layers.Activation;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
+
+public class LlamaFP16LayerPlanner extends FP16LayerPlanner<LlamaState, LlamaConfiguration, LlamaTornadoWeights> {
+
+    public LlamaFP16LayerPlanner(LlamaState state, Model model) {
+        super(state, model);
+        validateQuantizationType();
+        setupTornadoForwardPlan();
+    }
+
+    @Override
+    protected void initializeLayerComponents() {
+        this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
+        this.ffnLayers = new LlamaFP16FFNLayers("llamaFFN", this.state, this.weights, this.config, this.schedulerType);
+        this.logitsLayer = new LogitsFP16Layer("llamaLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
+    }
+
+}
\ No newline at end of file
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Phi3FP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Phi3FP16LayerPlanner.java
new file mode 100644
index 00000000..b1f41515
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Phi3FP16LayerPlanner.java
@@ -0,0 +1,34 @@
+package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16;
+
+import org.beehive.gpullama3.inference.state.Phi3State;
+import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.phi3.Phi3Configuration;
+import org.beehive.gpullama3.tornadovm.layerplanner.quantization.FP16LayerPlanner;
+import org.beehive.gpullama3.tornadovm.layers.Activation;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.Phi3FP16FFNLayers;
+
+/**
+ * Phi3FP16LayerPlanner: Phi3 model with FP16 weights.
+ *
+ * Follows the same pattern as Qwen3FP16LayerPlanner but with:
+ * - Phi3-specific FFN layers (combined QKV + gate/up FFN)
+ * - Phi3TornadoWeights
+ * - Phi3Configuration
+ *
+ * Inherits from FP16LayerPlanner
+ */
+public class Phi3FP16LayerPlanner extends FP16LayerPlanner<Phi3State, Phi3Configuration, Phi3TornadoWeights> {
+
+    public Phi3FP16LayerPlanner(Phi3State state, Model model) {
+        super(state, model);
+        validateQuantizationType();
+        setupTornadoForwardPlan();
+    }
+
+    @Override
+    protected void initializeLayerComponents() {
+        this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
+        this.ffnLayers = new Phi3FP16FFNLayers("phi3FFN", this.state, this.weights, this.config, this.schedulerType);
+        this.logitsLayer = new LogitsFP16Layer("phi3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
+    }
+
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen2FP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen2FP16LayerPlanner.java
new file mode 100644
index 00000000..b87dafd8
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen2FP16LayerPlanner.java
@@ -0,0 +1,33 @@
+package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16;
+
+import org.beehive.gpullama3.inference.state.Qwen2State;
+import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
+import org.beehive.gpullama3.tornadovm.layerplanner.quantization.FP16LayerPlanner;
+import org.beehive.gpullama3.tornadovm.layers.Activation;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.Qwen2FP16FFNLayers;
+
+/**
+ * Qwen2FP16LayerPlanner: Qwen2 model with FP16 weights.
+ *
+ * Follows the same pattern as LlamaFP16LayerPlanner but with:
+ * - Qwen2-specific FFN layers (supports GQA with bias terms)
+ * - Qwen2TornadoWeights
+ * - Qwen2Configuration
+ *
+ * Inherits from FP16LayerPlanner
+ */
+public class Qwen2FP16LayerPlanner extends FP16LayerPlanner<Qwen2State, Qwen2Configuration, Qwen2TornadoWeights> {
+
+    public Qwen2FP16LayerPlanner(Qwen2State state, Model model) {
+        super(state, model);
+        validateQuantizationType();
+        setupTornadoForwardPlan();
+    }
+
+    @Override
+    protected void initializeLayerComponents() {
+        this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
+        this.ffnLayers = new Qwen2FP16FFNLayers("qwen2FFN", this.state, this.weights, this.config, this.schedulerType);
+        this.logitsLayer = new LogitsFP16Layer("qwen2Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
+    }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen3FP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen3FP16LayerPlanner.java
new file mode 100644
index 00000000..ef3dcee4
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen3FP16LayerPlanner.java
@@ -0,0 +1,34 @@
+package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16;
+
+import org.beehive.gpullama3.inference.state.Qwen3State;
+import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
+import org.beehive.gpullama3.tornadovm.layerplanner.quantization.FP16LayerPlanner;
+import org.beehive.gpullama3.tornadovm.layers.Activation;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.Qwen3FP16FFNLayers;
+
+/**
+ * Qwen3FP16LayerPlanner: Qwen3 model with FP16 weights.
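For reference, the grouped-query attention (GQA) that the Qwen planners mention maps several query heads onto one shared key/value head; the kvMul parameter threaded through processHeadsParallel earlier in this patch is exactly that ratio. A small arithmetic sketch with illustrative values:

```java
// Illustrative GQA head-mapping arithmetic (values made up):
int nHeads = 32;                  // query heads, e.g. config.numberOfHeads()
int nKvHeads = 8;                 // key/value heads
int headSize = 128;               // e.g. config.headSize()
int kvMul = nHeads / nKvHeads;    // = 4: the kvMul passed to processHeadsParallel
int kvDim = nKvHeads * headSize;  // = 1024: the kvDim passed alongside it
for (int h = 0; h < nHeads; h++) {
    int kvHead = h / kvMul;             // KV head shared by query head h
    int kvOffset = kvHead * headSize;   // start of its K/V slice within a kvDim row
}
```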
+ *
+ * Follows the same pattern as LlamaFP16LayerPlanner but with:
+ * - Qwen3-specific FFN layers (supports GQA)
+ * - Qwen3TornadoWeights
+ * - Qwen3Configuration
+ *
+ * Inherits from FP16LayerPlanner
+ */
+public class Qwen3FP16LayerPlanner extends FP16LayerPlanner<Qwen3State, Qwen3Configuration, Qwen3TornadoWeights> {
+
+    public Qwen3FP16LayerPlanner(Qwen3State state, Model model) {
+        super(state, model);
+        validateQuantizationType();
+        setupTornadoForwardPlan();
+    }
+
+    @Override
+    protected void initializeLayerComponents() {
+        this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
+        this.ffnLayers = new Qwen3FP16FFNLayers("qwen3FFN", this.state, this.weights, this.config, this.schedulerType);
+        this.logitsLayer = new LogitsFP16Layer("qwen3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
+    }
+
+}
\ No newline at end of file
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/LlamaQ8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/LlamaQ8_0LayerPlanner.java
new file mode 100644
index 00000000..2560d8d7
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/LlamaQ8_0LayerPlanner.java
@@ -0,0 +1,27 @@
+package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0;
+
+import org.beehive.gpullama3.inference.state.LlamaState;
+import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.llama.LlamaConfiguration;
+import org.beehive.gpullama3.tornadovm.layerplanner.quantization.Q8_0LayerPlanner;
+import org.beehive.gpullama3.tornadovm.layers.Activation;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LlamaQ8_0FFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer;
+
+public class LlamaQ8_0LayerPlanner extends Q8_0LayerPlanner<LlamaState, LlamaConfiguration, LlamaTornadoWeights> {
+
+    public LlamaQ8_0LayerPlanner(LlamaState state, Model model) {
+        super(state, model);
+        validateQuantizationType();
+        setupTornadoForwardPlan();
+    }
+
+    @Override
+    protected void initializeLayerComponents() {
+        this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
+        this.ffnLayers = new LlamaQ8_0FFNLayers("llamaFFN", this.state, this.weights, this.config, this.schedulerType);
+        this.logitsLayer = new LogitsQ8_0Layer("llamaLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
+    }
+
+}
\ No newline at end of file
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Phi3Q8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Phi3Q8_0LayerPlanner.java
new file mode 100644
index 00000000..dfa0ec0e
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Phi3Q8_0LayerPlanner.java
@@ -0,0 +1,35 @@
+package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0;
+
+import org.beehive.gpullama3.inference.state.Phi3State;
+import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.phi3.Phi3Configuration;
+import org.beehive.gpullama3.tornadovm.layerplanner.quantization.Q8_0LayerPlanner;
+import org.beehive.gpullama3.tornadovm.layers.Activation;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Phi3Q8_0FFNLayers;
+
+/**
+ * Phi3Q8_0LayerPlanner: Phi3 model with Q8_0-quantized weights.
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Phi3Q8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Phi3Q8_0LayerPlanner.java
new file mode 100644
index 00000000..dfa0ec0e
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Phi3Q8_0LayerPlanner.java
@@ -0,0 +1,35 @@
+package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0;
+
+import org.beehive.gpullama3.inference.state.Phi3State;
+import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.phi3.Phi3Configuration;
+import org.beehive.gpullama3.tornadovm.layerplanner.quantization.Q8_0LayerPlanner;
+import org.beehive.gpullama3.tornadovm.layers.Activation;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Phi3Q8_0FFNLayers;
+
+/**
+ * Phi3Q8_0LayerPlanner: Phi3 model with Q8_0-quantized weights.
+ *
+ * Follows the same pattern as Qwen3Q8_0LayerPlanner but with:
+ * - Phi3-specific FFN layers (combined QKV + gate/up FFN)
+ * - Phi3TornadoWeights (8-bit integer quantization)
+ * - Phi3Configuration
+ * - 2x memory compression vs FP16
+ *
+ * Inherits from Q8_0LayerPlanner
+ */
+public class Phi3Q8_0LayerPlanner extends Q8_0LayerPlanner<Phi3State, Phi3TornadoWeights, Phi3Configuration> {
+
+    public Phi3Q8_0LayerPlanner(Phi3State state, Model model) {
+        super(state, model);
+        validateQuantizationType();
+        setupTornadoForwardPlan();
+    }
+
+    @Override
+    protected void initializeLayerComponents() {
+        this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
+        this.ffnLayers = new Phi3Q8_0FFNLayers("phi3FFN", this.state, this.weights, this.config, this.schedulerType);
+        this.logitsLayer = new LogitsQ8_0Layer("phi3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
+    }
+
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen2Q8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen2Q8_0LayerPlanner.java
new file mode 100644
index 00000000..34cb1a42
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen2Q8_0LayerPlanner.java
@@ -0,0 +1,35 @@
+package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0;
+
+import org.beehive.gpullama3.inference.state.Qwen2State;
+import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
+import org.beehive.gpullama3.tornadovm.layerplanner.quantization.Q8_0LayerPlanner;
+import org.beehive.gpullama3.tornadovm.layers.Activation;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Qwen2Q8_0FFNLayers;
+
+/**
+ * Qwen2Q8_0LayerPlanner: Qwen2 model with Q8_0-quantized weights.
+ *
+ * Follows the same pattern as LlamaQ8_0LayerPlanner but with:
+ * - Qwen2-specific FFN layers (supports GQA with bias terms)
+ * - Qwen2TornadoWeights (8-bit integer quantization)
+ * - Qwen2Configuration
+ * - 2x memory compression vs FP16
+ *
+ * Inherits from Q8_0LayerPlanner
+ */
+public class Qwen2Q8_0LayerPlanner extends Q8_0LayerPlanner<Qwen2State, Qwen2TornadoWeights, Qwen2Configuration> {
+
+    public Qwen2Q8_0LayerPlanner(Qwen2State state, Model model) {
+        super(state, model);
+        validateQuantizationType();
+        setupTornadoForwardPlan();
+    }
+
+    @Override
+    protected void initializeLayerComponents() {
+        this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
+        this.ffnLayers = new Qwen2Q8_0FFNLayers("qwen2FFN", this.state, this.weights, this.config, this.schedulerType);
+        this.logitsLayer = new LogitsQ8_0Layer("qwen2Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
+    }
+
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen3Q8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen3Q8_0LayerPlanner.java
new file mode 100644
index 00000000..fb4d4ef3
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen3Q8_0LayerPlanner.java
@@ -0,0 +1,34 @@
+package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0;
+
+import org.beehive.gpullama3.inference.state.Qwen3State;
+import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
+import org.beehive.gpullama3.tornadovm.layerplanner.quantization.Q8_0LayerPlanner;
+import org.beehive.gpullama3.tornadovm.layers.Activation;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Qwen3Q8_0FFNLayers;
+
+/**
+ * Qwen3Q8_0LayerPlanner: Qwen3 model with Q8_0-quantized weights.
+ *
+ * Follows the same pattern as LlamaQ8_0LayerPlanner but with:
+ * - Qwen3-specific FFN layers (supports GQA)
+ * - Qwen3TornadoWeights (8-bit integer quantization)
+ * - Qwen3Configuration
+ * - 2x memory compression vs FP16
+ *
+ * Inherits from Q8_0LayerPlanner
+ */
+public class Qwen3Q8_0LayerPlanner extends Q8_0LayerPlanner<Qwen3State, Qwen3TornadoWeights, Qwen3Configuration> {
+
+    public Qwen3Q8_0LayerPlanner(Qwen3State state, Model model) {
+        super(state, model);
+        validateQuantizationType();
+        setupTornadoForwardPlan();
+    }
+
+    @Override
+    protected void initializeLayerComponents() {
+        this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
+        this.ffnLayers = new Qwen3Q8_0FFNLayers("qwen3FFN", this.state, this.weights, this.config, this.schedulerType);
+        this.logitsLayer = new LogitsQ8_0Layer("qwen3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
+    }
+}
\ No newline at end of file
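The six planners above are deliberately thin: each binds one model family's state, weights, and configuration types to a quantization base class that does the actual graph assembly. A dispatch sketch in the same spirit (hypothetical: the createPlanner helper, the ModelType constants other than MISTRAL, and passing the weight type explicitly are illustrative assumptions, not part of this patch):

    // Sketch: mapping (model family, weight type) onto the planners above.
    static Object createPlanner(Model model, State state, GGMLType weightType) {
        boolean fp16 = (weightType == GGMLType.F16);    // otherwise GGMLType.Q8_0
        return switch (model.getModelType()) {          // enum constant names are assumed
            case QWEN_2 -> fp16 ? new Qwen2FP16LayerPlanner((Qwen2State) state, model)
                                : new Qwen2Q8_0LayerPlanner((Qwen2State) state, model);
            case QWEN_3 -> fp16 ? new Qwen3FP16LayerPlanner((Qwen3State) state, model)
                                : new Qwen3Q8_0LayerPlanner((Qwen3State) state, model);
            default -> throw new IllegalArgumentException("No planner for " + model.getModelType());
        };
    }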
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/FP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/FP16LayerPlanner.java
new file mode 100644
index 00000000..9be5e08b
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/FP16LayerPlanner.java
@@ -0,0 +1,90 @@
+package org.beehive.gpullama3.tornadovm.layerplanner.quantization;
+
+import org.beehive.gpullama3.tensor.GGMLType;
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
+import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.tornadovm.layerplanner.base.QuantizedLayerPlanner;
+import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.Activation;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer;
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Base for all FP16-quantized layer planners.
+ *
+ * Subclasses: LlamaFP16LayerPlanner, Qwen2FP16LayerPlanner, etc.
+ *
+ * FP16 Specific:
+ * - Uses half-precision floating point kernels
+ * - Weights: weights.xxxHalfFloat arrays
+ * - Compute: 2x faster than FP32 on modern GPUs
+ */
+public abstract class FP16LayerPlanner<S extends State, W extends TornadoWeights, C extends Configuration> extends QuantizedLayerPlanner<S, W, C> {
+
+    protected Activation activationLayer;
+    protected AbstractFFNLayers ffnLayers;
+    protected LogitsFP16Layer logitsLayer;
+
+    protected List<ImmutableTaskGraph> immutableTaskGraphs;
+    protected GridScheduler gridScheduler;
+
+    protected FP16LayerPlanner(S state, Model model) {
+        super(state, model);
+        initializeLayerComponents();
+    }
+
+    @Override
+    protected void validateQuantizationType() {
+        if (this.weights.getWeightType() != GGMLType.F16) {
+            throw new IllegalArgumentException("FP16LayerPlanner requires GGMLType.F16, got: " + this.weights.getWeightType());
+        }
+    }
+
+    @Override
+    protected void initializeLayerComponents() {
+        // Overridden in model-specific subclasses (LlamaFP16LayerPlanner, etc.)
+    }
+
+    protected final void setupTornadoForwardPlan() {
+        List<ImmutableTaskGraph> allTaskGraphs = new ArrayList<>();
+        GridScheduler masterScheduler = new GridScheduler();
+
+        // 1. Activation layer (common to all models)
+        allTaskGraphs.add(activationLayer.getImmutableTaskGraph());
+        activationLayer.updateGridScheduler(masterScheduler);
+
+        // 2. FFN layers (N transformer layers - model-specific)
+        allTaskGraphs.addAll(ffnLayers.getFfnLayerTaskGraphs());
+        ffnLayers.updateGridScheduler(masterScheduler);
+
+        // 3. Logits layer (common to all models)
+        allTaskGraphs.add(logitsLayer.getTaskGraph().snapshot());
+        logitsLayer.updateGridScheduler(masterScheduler);
+
+        // Cache for future retrievals
+        this.immutableTaskGraphs = allTaskGraphs;
+        this.gridScheduler = masterScheduler;
+    }
+
+    /**
+     * Returns cached task graphs (used by hardware strategy pattern).
+     *
+     * Removed from all model-specific planners - centralized here.
+     */
+    public final List<ImmutableTaskGraph> getImmutableTaskGraphs() {
+        return this.immutableTaskGraphs;
+    }
+
+    /**
+     * Returns cached scheduler (used by hardware strategy pattern).
+     *
+     * Removed from all model-specific planners - centralized here.
+     */
+    @Override
+    public final GridScheduler getGridScheduler() {
+        return this.gridScheduler;
+    }
+
+}
\ No newline at end of file
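setupTornadoForwardPlan() is the whole template method: one activation graph, N per-layer FFN graphs, then the logits graph, all registered on a single GridScheduler. Given a planner instance for the current model, consuming the cached artifacts is a few lines against the TornadoVM API; a minimal sketch (assuming the standard varargs TornadoExecutionPlan constructor of the pinned TornadoVM build):

    // Sketch: executing one forward pass from a planner's cached graphs.
    ImmutableTaskGraph[] graphs = planner.getImmutableTaskGraphs().toArray(new ImmutableTaskGraph[0]);
    TornadoExecutionPlan plan = new TornadoExecutionPlan(graphs);
    plan.withGridScheduler(planner.getGridScheduler())
        .execute();   // activation -> layer_0 ... layer_N-1 -> logits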
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/Q8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/Q8_0LayerPlanner.java
new file mode 100644
index 00000000..f10f9686
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/quantization/Q8_0LayerPlanner.java
@@ -0,0 +1,93 @@
+package org.beehive.gpullama3.tornadovm.layerplanner.quantization;
+
+import org.beehive.gpullama3.tensor.GGMLType;
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
+import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.tornadovm.layerplanner.base.QuantizedLayerPlanner;
+import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.Activation;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer;
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Base for all Q8_0-quantized layer planners.
+ *
+ * Subclasses: LlamaQ8_0LayerPlanner, Qwen2Q8_0LayerPlanner, etc.
+ *
+ * Q8_0 Specific:
+ * - Uses 8-bit integer quantization with uniform scaling per 32-element block
+ * - Weights: weights.xxxByteArray arrays
+ * - Compute: dequantize on-the-fly during matmul
+ * - Memory: 2x compression vs FP16
+ */
+public abstract class Q8_0LayerPlanner<S extends State, W extends TornadoWeights, C extends Configuration> extends QuantizedLayerPlanner<S, W, C> {
+
+    protected Activation activationLayer;
+    protected AbstractFFNLayers ffnLayers;
+    protected LogitsQ8_0Layer logitsLayer;
+
+    // Cache for task graphs and scheduler (set once, reused)
+    protected List<ImmutableTaskGraph> cachedTaskGraphs;
+    protected GridScheduler cachedScheduler;
+
+    protected Q8_0LayerPlanner(S state, Model model) {
+        super(state, model);
+        initializeLayerComponents();
+    }
+
+    @Override
+    protected void validateQuantizationType() {
+        if (this.weights.getWeightType() != GGMLType.Q8_0) {
+            throw new IllegalArgumentException("Q8_0LayerPlanner requires GGMLType.Q8_0, got: " + this.weights.getWeightType());
+        }
+    }
+
+    @Override
+    protected void initializeLayerComponents() {
+        // Override in subclasses (LlamaQ8_0LayerPlanner, etc.)
+    }
+
+    protected final void setupTornadoForwardPlan() {
+        List<ImmutableTaskGraph> allTaskGraphs = new ArrayList<>();
+        GridScheduler masterScheduler = new GridScheduler();
+
+        // 1. Activation layer (common to all models)
+        allTaskGraphs.add(activationLayer.getImmutableTaskGraph());
+        activationLayer.updateGridScheduler(masterScheduler);
+
+        // 2. FFN layers (N transformer layers - model-specific)
+        allTaskGraphs.addAll(ffnLayers.getFfnLayerTaskGraphs());
+        ffnLayers.updateGridScheduler(masterScheduler);
+
+        // 3. Logits layer (common to all models)
+        allTaskGraphs.add(logitsLayer.getTaskGraph().snapshot());
+        logitsLayer.updateGridScheduler(masterScheduler);
+
+        // Cache for future retrievals
+        this.cachedTaskGraphs = allTaskGraphs;
+        this.cachedScheduler = masterScheduler;
+    }
+
+    /**
+     * Returns cached task graphs (used by hardware strategy pattern).
+     *
+     * Removed from all model-specific planners - centralized here.
+     */
+    public final List<ImmutableTaskGraph> getImmutableTaskGraphs() {
+        return this.cachedTaskGraphs;
+    }
+
+    /**
+     * Returns cached scheduler (used by hardware strategy pattern).
+     *
+     * Removed from all model-specific planners - centralized here.
+     */
+    @Override
+    public final GridScheduler getGridScheduler() {
+        return this.cachedScheduler;
+    }
+
+}
\ No newline at end of file
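The "2x compression vs FP16" figure in the Javadoc follows directly from the GGUF Q8_0 block layout: 32 int8 weights share one fp16 scale. Back-of-the-envelope check (plain arithmetic, no project APIs involved):

    int q8BlockBytes = 32 * Byte.BYTES + Short.BYTES;    // 32 int8 values + fp16 scale = 34 bytes
    double q8BitsPerWeight = q8BlockBytes * 8.0 / 32;    // 8.5 bits per weight
    double ratioVsFp16 = 16.0 / q8BitsPerWeight;         // ~1.88x, rounded to "2x" above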
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/strategy/SchedulerDetectionService.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/strategy/SchedulerDetectionService.java
new file mode 100644
index 00000000..5a81caa8
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/strategy/SchedulerDetectionService.java
@@ -0,0 +1,27 @@
+package org.beehive.gpullama3.tornadovm.layerplanner.strategy;
+
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.ModelType;
+import uk.ac.manchester.tornado.api.TornadoRuntime;
+import uk.ac.manchester.tornado.api.runtime.TornadoRuntimeProvider;
+
+import java.util.Locale;
+
+public class SchedulerDetectionService {
+
+    public static SchedulerType determineSchedulerType(Model model) {
+        TornadoRuntime tornadoRuntime = TornadoRuntimeProvider.getTornadoRuntime();
+        String platformName = tornadoRuntime.getBackend(0)
+                .getDefaultDevice()
+                .getPlatformName()
+                .toLowerCase(Locale.ROOT);
+
+        boolean isNvidia = platformName.contains("nvidia") ||
+                platformName.contains("cuda") ||
+                platformName.contains("ptx");
+        boolean isNotMistral = model.getModelType() != ModelType.MISTRAL;
+
+        return (isNvidia && isNotMistral) ? SchedulerType.NVIDIA : SchedulerType.NON_NVIDIA;
+    }
+}
\ No newline at end of file
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/strategy/SchedulerType.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/strategy/SchedulerType.java
new file mode 100644
index 00000000..28b568a2
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/strategy/SchedulerType.java
@@ -0,0 +1,5 @@
+package org.beehive.gpullama3.tornadovm.layerplanner.strategy;
+
+public enum SchedulerType {
+    NVIDIA, NON_NVIDIA
+}
\ No newline at end of file
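The detection boils down to a substring test on the backend's platform name plus one model-specific exception. Pulled out as a pure predicate it is trivially table-testable (an editorial sketch; the patch keeps this logic inline):

    static boolean looksLikeNvidia(String platformName) {
        String p = platformName.toLowerCase(java.util.Locale.ROOT);
        return p.contains("nvidia") || p.contains("cuda") || p.contains("ptx");
    }
    // looksLikeNvidia("NVIDIA CUDA")      -> true  => SchedulerType.NVIDIA (unless the model is Mistral)
    // looksLikeNvidia("Intel(R) OpenCL")  -> false => SchedulerType.NON_NVIDIA
    // Mistral models always get NON_NVIDIA, regardless of platform.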
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java
new file mode 100644
index 00000000..3b0620c6
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractFFNLayers.java
@@ -0,0 +1,78 @@
+package org.beehive.gpullama3.tornadovm.layers;
+
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.inference.weights.Weights;
+import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+
+import java.util.List;
+
+/**
+ * Abstract base class for all FFN (Feed-Forward Network) layer implementations.
+ *
+ * Extends AbstractLayer and adds FFN-specific methods:
+ * - getFfnLayerTaskGraphs(): Returns task graphs for all transformer layers
+ * - getLastTaskGraphID(): Tracks the ID of the last task graph
+ *
+ * All model-specific FFN layers extend this:
+ * - LlamaFP16FFNLayers, Qwen2FP16FFNLayers, Qwen3FP16FFNLayers, Phi3FP16FFNLayers
+ * - LlamaQ8_0FFNLayers, Qwen2Q8_0FFNLayers, Qwen3Q8_0FFNLayers, Phi3Q8_0FFNLayers
+ *
+ * Used by FP16LayerPlanner and Q8_0LayerPlanner template methods for type-safe polymorphic access to any FFN layer implementation.
+ */
+public abstract class AbstractFFNLayers extends AbstractLayer {
+
+    protected final SchedulerType schedulerType;
+
+    /**
+     * Constructor for FFN layers.
+     *
+     * @param taskGraphName
+     *         Name for the task graph
+     * @param state
+     *         Runtime state (LlamaState, Qwen2State, etc.)
+     * @param weights
+     *         Model weights (FP16Weights, Q8_0Weights, etc.)
+     * @param config
+     *         Model configuration
+     * @param schedulerType
+     *         Detected hardware scheduler type (NVIDIA or NON_NVIDIA)
+     */
+    protected AbstractFFNLayers(String taskGraphName, State state, Weights weights, Configuration config, SchedulerType schedulerType) {
+        super(taskGraphName, state, weights, config);
+        this.schedulerType = schedulerType;
+    }
+
+    /**
+     * Returns all task graphs for the FFN layers.
+     *
+     * For a model with N transformer layers, this returns N ImmutableTaskGraphs, one for each layer (containing RMSNorm, Attention, FFN computations).
+     *
+     * @return List of immutable task graphs (one per transformer layer)
+     */
+    public abstract List<ImmutableTaskGraph> getFfnLayerTaskGraphs();
+
+    /**
+     * Get the ID of the last task graph.
+     *
+     * Used by LogitsLayer to know where to attach the final logits computation. The last transformer layer's task graph ID is needed to chain the logits computation after all FFN layers complete.
+     *
+     * @return Task graph ID of the last FFN layer
+     */
+    @Override
+    public String getLastTaskGraphID() {
+        return lastTaskGraphID;
+    }
+
+    /**
+     * Indicates whether an explicit final normalization step is required for the detected hardware.
+     *
+     * The scheduler type also selects the attention kernel in subclasses:
+     * - NVIDIA hardware: uses Flash Attention for optimized performance
+     * - NON_NVIDIA hardware: uses parallel head processing
+     *
+     * Subclasses consult this method during task graph setup.
+     *
+     * @return true if the final normalization step should be used (NON_NVIDIA), false otherwise
+     */
+    protected boolean shouldUseFinalNormalization() {
+        return schedulerType == SchedulerType.NON_NVIDIA;
+    }
+}
\ No newline at end of file
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java
new file mode 100644
index 00000000..6578777f
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/AbstractLayer.java
@@ -0,0 +1,73 @@
+package org.beehive.gpullama3.tornadovm.layers;
+
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.inference.weights.Weights;
+import org.beehive.gpullama3.model.Configuration;
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+import uk.ac.manchester.tornado.api.KernelContext;
+import uk.ac.manchester.tornado.api.TaskGraph;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Minimal base with common fields/utilities so subclasses compile cleanly. Adjust or remove fields if they already exist in your project.
+ */
+public abstract class AbstractLayer {
+
+    /** Common constants used in tasks & worker-grid sizing. */
+    protected static final int LOCAL_WORK_GROUP_SIZE_ALLOC = 32;
+    protected static final int THREAD_SCALE_FOR_LOGITS = 8;
+    protected static String lastTaskGraphID;
+    protected final Weights weights;
+    protected final Configuration config;
+    /** Often a small context/config buffer passed into kernels. Use your real type if available. */
+    protected final KernelContext context = new KernelContext();
+    /** Collected snapshots for scheduling / debugging. */
+    protected final List<ImmutableTaskGraph> taskGraphs = new ArrayList<>();
+    /** Optional: track the "main" task graph for the layer if one exists. */
+    protected TaskGraph taskGraph;
+    /** Shared runtime objects (exposed because kernels expect them). */
+    protected State state;
+
+    protected AbstractLayer(String taskGraphName, State state, Weights weights, Configuration config) {
+        this.taskGraph = null;
+        this.state = state;
+        this.weights = weights;
+        this.config = config;
+    }
+
+    @SuppressWarnings("unchecked")
+    protected static <T> T requireWeightsType(Object weights, Class<T> expectedType, String layerName, String layout) {
+        if (expectedType.isInstance(weights)) {
+            return (T) weights;
+        }
+        throw new IllegalArgumentException(layerName + " requires " + expectedType.getSimpleName() + " with " + layout + " layout");
+    }
+
+    public abstract GridScheduler updateGridScheduler(GridScheduler scheduler);
+
+    public abstract GridScheduler getGridScheduler();
+
+    public abstract TaskGraph getTaskGraph();
+
+    public abstract ImmutableTaskGraph getImmutableTaskGraph();
+
+    /** Allow subclasses to override if they need custom transfers. */
+    protected TaskGraph configureLayerDataTransfers(TaskGraph tg, int layerIndex) {
+        return tg;
+    }
+
+    public String getLastTaskGraphID() {
+        return lastTaskGraphID;
+    }
+
+    public void setupLastID(String taskGraphID) {
+        if (lastTaskGraphID == null) {
+            lastTaskGraphID = taskGraphID;
+        } else if (!lastTaskGraphID.equals(taskGraphID)) {
+            throw new IllegalStateException("Task graph IDs do not match: " + lastTaskGraphID + " vs " + taskGraphID);
+        }
+    }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java
new file mode 100644
index 00000000..16783829
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java
@@ -0,0 +1,49 @@
+package org.beehive.gpullama3.tornadovm.layers;
+
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.inference.weights.Weights;
+import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
+import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+import uk.ac.manchester.tornado.api.TaskGraph;
+import uk.ac.manchester.tornado.api.WorkerGrid;
+import uk.ac.manchester.tornado.api.enums.DataTransferMode;
+
+public class Activation extends AbstractLayer {
+    private final TaskGraph activationUpdate;
+
+    public Activation(String taskGraphHandle, State state, Weights weights, Configuration config) {
+        super(taskGraphHandle, state, weights, config);
+
+        // @formatter:off
+        this.activationUpdate = new TaskGraph(taskGraphHandle)
+                .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.wrapX)
+                .task("updateX", TransformerComputeKernels::emptyTaskToForceCopyIn, state.wrapX)
+                .persistOnDevice(state.wrapX);
+        // @formatter:on
+    }
+
+    @Override
+    public GridScheduler updateGridScheduler(GridScheduler scheduler) {
+        WorkerGrid singleWorker = WorkerGridFactory.createSingleWorker();
+        scheduler.addWorkerGrid("activationUpdate.updateX", singleWorker);
+        return scheduler;
+    }
+
+    @Override
+    public GridScheduler getGridScheduler() {
+        return null;
+    }
+
+    @Override
+    public TaskGraph getTaskGraph() {
+        return activationUpdate;
+    }
+
+    @Override
+    public ImmutableTaskGraph getImmutableTaskGraph() {
+        return activationUpdate.snapshot();
+    }
+
+}
+
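The activation graph holds a single, intentionally empty task: declaring wrapX as a task argument is what makes TornadoVM honor the EVERY_EXECUTION transfer and keep the buffer resident via persistOnDevice. The referenced kernel is plausibly nothing more than the following (an assumption about TransformerComputeKernels, whose source is outside this diff; FloatArray is the assumed buffer type):

    // Empty kernel: exists only so wrapX appears as a task argument,
    // which forces the scheduled host-to-device copy on every execution.
    public static void emptyTaskToForceCopyIn(FloatArray buffer) {
        // intentionally empty
    }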
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java
new file mode 100644
index 00000000..96acd650
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java
@@ -0,0 +1,178 @@
+package org.beehive.gpullama3.tornadovm.layers.type.fp16;
+
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.inference.weights.Weights;
+import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
+import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
+import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+import uk.ac.manchester.tornado.api.TaskGraph;
+import uk.ac.manchester.tornado.api.WorkerGrid;
+import uk.ac.manchester.tornado.api.enums.DataTransferMode;
+
+import java.util.List;
+import java.util.stream.IntStream;
+
+public class LlamaFP16FFNLayers extends AbstractFFNLayers {
+
+    TaskGraph ffnTaskGraphs;
+    GridScheduler scheduler;
+    List<ImmutableTaskGraph> ffnLayerTaskGraphs;
+
+    public LlamaFP16FFNLayers(String taskGraph, State state, Weights weights, Configuration config, SchedulerType schedulerType) {
+        super(taskGraph, state, weights, config, schedulerType);
+        this.ffnLayerTaskGraphs = setupFFNLayered();
+    }
+
+    @Override
+    public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
+        WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 128);
+        WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256);
+
+        int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
+        WorkerGrid configDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+        int configKvDimRowMajorGlobal = config.kvDim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
+        WorkerGrid configKvDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configKvDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+        int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
+        WorkerGrid configHiddenDimRowMajorWorker = WorkerGridFactory.genericWorker(configHiddenDimRowMajor, LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+        WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize());
+        WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(config.dim(), 128);
+
+        // Map workers to tasks
+        for (int i = 0; i < config.numberOfLayers(); i++) {
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", configDimRowMajorGlobalWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", configKvDimRowMajorGlobalWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", configKvDimRowMajorGlobalWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker);
+        }
+        return tornadoForwardScheduler;
+    }
+
+    @Override
+    public GridScheduler getGridScheduler() {
+        return scheduler;
+    }
+
+    @Override
+    public TaskGraph getTaskGraph() {
+        return ffnTaskGraphs;
+    }
+
+    @Override
+    public ImmutableTaskGraph getImmutableTaskGraph() {
+        return null;
+    }
+
+    public List<ImmutableTaskGraph> getFfnLayerTaskGraphs() {
+        return ffnLayerTaskGraphs;
+    }
+
+    List<ImmutableTaskGraph> setupFFNLayered() {
+        state.temp.init(0.0f);
+        state.tempFFN.init(0.0f);
+        var numLayers = config.numberOfLayers();
+
+        return IntStream.range(0, numLayers)
+                .mapToObj(i -> {
+                    var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i);
+                    if (i == numLayers - 1) {
+                        setupLastID(ffnLayer.getTaskGraphName());
+                    }
+                    return ffnLayer.snapshot();
+                })
+                .toList();
+    }
+
+    TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, int layerIndex) {
+        var layerTaskGraphName = "layer_" + layerIndex;
+        TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName);
+        unifiedLayer.consumeFromDevice(state.wrapX);
+        unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
+                weights.rms_att_weightLayered[layerIndex].asFloatArray(),
+                weights.wqLayered[layerIndex].asHalfFloatArray(),
+                weights.wkLayered[layerIndex].asHalfFloatArray(),
+                weights.wvLayered[layerIndex].asHalfFloatArray(),
+                weights.woLayered[layerIndex].asHalfFloatArray(),
+                weights.rms_ffn_weightLayered[layerIndex].asFloatArray(),
+                weights.w1Layered[layerIndex].asHalfFloatArray(),
+                weights.w2Layered[layerIndex].asHalfFloatArray(),
+                weights.w3Layered[layerIndex].asHalfFloatArray());
+        unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
+        unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
+        if (shouldUseFinalNormalization()) {
+            unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp,
+                    config.dim(), config.rmsNormEps());
+        }
+        unifiedLayer.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp)
+                .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(),
+                        LOCAL_WORK_GROUP_SIZE_ALLOC)
+                .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapK, weights.wkLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(),
+                        LOCAL_WORK_GROUP_SIZE_ALLOC)
+                .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapV, weights.wvLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(),
+                        LOCAL_WORK_GROUP_SIZE_ALLOC)
+                .task("rope", TransformerComputeKernelsLayered::ropeRotation, context, state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(), config.headSize())
+                .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(),
+                        layerIndex, config.contextLength());
+        configureAttention(unifiedLayer, layerIndex);
+        unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapXb, state.wrapX, weights.woLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(),
+                        LOCAL_WORK_GROUP_SIZE_ALLOC)
+                .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
+        if (shouldUseFinalNormalization()) {
+            unifiedLayer.task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN, config.dim(), config.rmsNormEps());
+        }
+        unifiedLayer.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN)
+                .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(),
+                        weights.w3Layered[layerIndex].asHalfFloatArray(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
+                .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapHb, state.wrapX, weights.w2Layered[layerIndex].asHalfFloatArray(), config.hiddenDim(),
+                        config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
+                .persistOnDevice(state.wrapX);
+        return unifiedLayer;
+    }
+
+    protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) {
+        // First layer: Transfer initial data to device (one-time transfer)
+        if (layerIndex == 0) {
+            // Transfer all attention-related data: query, key, value matrices and their caches
+            unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder, state.temp, state.tempFFN); //
+            unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, //
+                    context, state.wrapXb, state.wrapXb2, //
+                    state.wrapQ, state.wrapK, state.wrapV, //
+                    state.wrapKeyCache, state.wrapValueCache, //
+                    state.wrapAtt, state.wrapHb); //
+        } else {
+            // Subsequent layers: Consume data already on device from previous layer
+            unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2, //
+                    state.wrapQ, state.wrapK, state.wrapV, //
+                    state.wrapKeyCache, state.wrapValueCache, //
+                    state.wrapAtt, state.wrapHb, //
+                    state.positionHolder //
+            );
+        }
+        return unifiedLayer;
+    }
+
+    private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) {
+        if (schedulerType == SchedulerType.NVIDIA) {
+            return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttention,
+                    context, state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb,
+                    config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(),
+                    state.positionHolder, layerIndex, config.contextLength());
+        } else {
+            return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsParallel, state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb,
+                    config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), config.contextLength(), state.positionHolder, state.wrapAtt, layerIndex, config.contextLength());
+        }
+    }
+}
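All row-major matvec tasks above share one sizing rule: global work size = output rows * LOCAL_WORK_GROUP_SIZE_ALLOC, i.e. one 32-thread workgroup reduces one output row. Concrete numbers, assuming Llama-3.2-1B shapes (dim 2048, hiddenDim 8192; the shapes are illustrative, not read from this diff):

    int lws = 32;                        // LOCAL_WORK_GROUP_SIZE_ALLOC
    int globalQmatmul = 2048 * lws;      // 65,536 threads = 2,048 groups, one per row of wq
    int globalFfnW1W3 = 8192 * lws;      // 262,144 threads for fused_ffn_w1_w3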
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java
new file mode 100644
index 00000000..a674c1c5
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsFP16Layer.java
@@ -0,0 +1,88 @@
+package org.beehive.gpullama3.tornadovm.layers.type.fp16;
+
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.inference.weights.Weights;
+import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
+import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
+import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
+import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
+import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.layers.AbstractLayer;
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+import uk.ac.manchester.tornado.api.TaskGraph;
+import uk.ac.manchester.tornado.api.WorkerGrid;
+import uk.ac.manchester.tornado.api.WorkerGrid1D;
+import uk.ac.manchester.tornado.api.enums.DataTransferMode;
+
+public class LogitsFP16Layer extends AbstractLayer {
+
+    private String lastTaskGraphID;
+    private TaskGraph logitsTaskGraph;
+    private ImmutableTaskGraph immutableLogitsGraph;
+    private GridScheduler scheduler;
+    private SchedulerType schedulerType;
+
+    public LogitsFP16Layer(String name, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) {
+        super(name, state, weights, config);
+        this.lastTaskGraphID = lastTaskGraphID;
+        // Must be assigned before setupLogitsTaskGraph, which reads schedulerType
+        // to decide whether to append the final-normalization task.
+        this.schedulerType = schedulerType;
+        state.tempLogits.init(0.0f);
+        var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsFP16Layer", "TornadoTensor");
+        this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, config);
+    }
+
+    /**
+     * Builds the logits computation graph.
+     */
+    private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) {
+        TaskGraph logits = new TaskGraph("logits");
+        logits.consumeFromDevice(lastTaskGraphID, state.wrapX)
+                .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits)
+                .transferToDevice(DataTransferMode.FIRST_EXECUTION, context, state.wrapLogits, weights.wclsByteArray.asHalfFloatArray(), weights.rms_final_weight_as_floatArray.asFloatArray())
+                .task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
+        if (schedulerType == SchedulerType.NON_NVIDIA) {
+            logits.task("reductionFinalNormalizationLogits", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempLogits, config.dim(), config.rmsNormEps());
+        }
+        logits.task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, weights.rms_final_weight_as_floatArray.asFloatArray(), state.tempLogits)
+                .task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapX, state.wrapLogits, weights.wclsByteArray.asHalfFloatArray(), config.dim(), config.vocabularySize(),
+                        LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS);
+        logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
+        return logits;
+    }
+
+    @Override
+    public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
+        WorkerGrid logitsRMS;
+        if (weights instanceof Qwen2TornadoWeights) {
+            logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 32);
+        } else {
+            logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 256);
+        }
+
+        var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS;
+        WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor);
+        vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1);
+
+        tornadoForwardScheduler.addWorkerGrid("logits.projection", vocabWorker);
+        tornadoForwardScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", logitsRMS);
+        tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", logitsRMS);
+        return tornadoForwardScheduler;
+    }
+
+    @Override
+    public GridScheduler getGridScheduler() {
+        return scheduler;
+    }
+
+    @Override
+    public TaskGraph getTaskGraph() {
+        return logitsTaskGraph;
+    }
+
+    @Override
+    public ImmutableTaskGraph getImmutableTaskGraph() {
+        return immutableLogitsGraph;
+    }
+}
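The logits projection is the widest matvec in the model, so its worker is scaled THREAD_SCALE_FOR_LOGITS (8x) denser than the per-layer ones: 256 threads cooperate on each vocabulary row. With an assumed Llama 3 vocabulary of 128,256 entries (illustrative, not read from this diff):

    int threadsPerRow = 32 * 8;                    // LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS
    long globalLogits = 128_256L * threadsPerRow;  // ~32.8M work-items in the "logits.projection" launch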
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java
new file mode 100644
index 00000000..75f9f531
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java
@@ -0,0 +1,329 @@
+package org.beehive.gpullama3.tornadovm.layers.type.fp16;
+
+import org.beehive.gpullama3.inference.state.Phi3State;
+import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights;
+import org.beehive.gpullama3.model.phi3.Phi3Configuration;
+import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
+import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+import uk.ac.manchester.tornado.api.TaskGraph;
+import uk.ac.manchester.tornado.api.WorkerGrid;
+import uk.ac.manchester.tornado.api.enums.DataTransferMode;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Phi3FP16FFNLayers: FP16 FFN layers for Phi3 with Group Query Attention (GQA) support.
+ *
+ * Key Differences from Qwen2/Qwen3:
+ * - Uses combined QKV matrix (wqkv) instead of separate Q, K, V matrices
+ * - Includes splitQKV task to separate combined buffer
+ * - Uses ropeRotationPhi3 kernel for position embeddings
+ * - FFN uses single wUp matrix that outputs both Gate and Up (2 * hiddenDim)
+ * - Includes splitGateUpAndSiLU task for FFN activation
+ * - Uses wDown for final FFN projection
+ * - No Q, K, V bias terms
+ *
+ * Works directly with Phi3State to access and mutate Phi3-specific state fields.
+ */
+public class Phi3FP16FFNLayers extends AbstractFFNLayers {
+
+    TaskGraph ffnLayerTaskGraph;
+    GridScheduler scheduler;
+    List<ImmutableTaskGraph> ffnLayerTaskGraphs;
+
+    // Typed references to Phi3-specific state and config
+    private final Phi3State phi3State;
+    private final Phi3Configuration phi3Config;
+
+    // Phi3-specific dimension for combined QKV buffer
+    private final int opSize;
+
+    public Phi3FP16FFNLayers(String taskGraphName, Phi3State state, Phi3TornadoWeights weights, Phi3Configuration config, SchedulerType schedulerType) {
+        super(taskGraphName, state, weights, config, schedulerType);
+        this.phi3State = state;
+        this.phi3Config = config;
+
+        // Ensure we have Phi3-specific weights
+        if (!(weights instanceof Phi3TornadoWeights)) {
+            throw new IllegalArgumentException("Phi3FP16FFNLayers requires Phi3TornadoWeights with TornadoTensor layout");
+        }
+
+        // Calculate opSize for combined QKV buffer
+        // opSize = num_heads * head_dim + 2 * (num_key_value_heads * head_dim) = dim + 2 * kvDim
+        this.opSize = config.dim() + 2 * (config.numberOfKeyValueHeads() * config.headSize());
+        ffnLayerTaskGraphs = setupFFNLayered();
+    }
+
+    @Override
+    public GridScheduler updateGridScheduler(GridScheduler gridScheduler) {
+        // RMS norm worker
+        WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), state.localSize);
+
+        // Combined QKV matmul worker
+        int matmulQkvGlobal = opSize * LOCAL_WORK_GROUP_SIZE_ALLOC;
+        WorkerGrid matmulQkvRowMajorWorker = WorkerGridFactory.genericWorker(matmulQkvGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+        // RoPE worker (2D: heads x embedding_head/2)
+        WorkerGrid ropeWorker = WorkerGridFactory.createRoPEWorker(config.numberOfHeads(), config.headSize());
+
+        // Copy to cache worker
+        WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(config.kvDim(), 32);
+
+        // Parallel attention worker
+        WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize());
+
+        // Matmul1 worker (output projection)
+        int matmul1Global = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
+        WorkerGrid matmul1Worker = WorkerGridFactory.genericWorker(matmul1Global, LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+        // FFN workers
+        int ffnUpGlobal = (2 * config.hiddenDim()) * LOCAL_WORK_GROUP_SIZE_ALLOC;
+        WorkerGrid ffnUpWorker = WorkerGridFactory.genericWorker(ffnUpGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+        int ffnDownGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
+        WorkerGrid ffnDownWorker = WorkerGridFactory.genericWorker(ffnDownGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+        // Map workers to tasks for each layer
+        for (int i = 0; i < config.numberOfLayers(); i++) {
+            gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker);
+            gridScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker);
+            gridScheduler.addWorkerGrid("layer_" + i + ".qkvmatmul", matmulQkvRowMajorWorker);
+            gridScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker);
+            gridScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker);
+            gridScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker);
+            gridScheduler.addWorkerGrid("layer_" + i + ".matmul1", matmul1Worker);
+            gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker);
+            gridScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker);
+            gridScheduler.addWorkerGrid("layer_" + i + ".wGateUp", ffnUpWorker);
".wDown", ffnDownWorker); + } + + return gridScheduler; + } + + @Override + public GridScheduler getGridScheduler() { + return scheduler; + } + + @Override + public TaskGraph getTaskGraph() { + return ffnLayerTaskGraph; + } + + @Override + public ImmutableTaskGraph getImmutableTaskGraph() { + return null; + } + + public List getFfnLayerTaskGraphs() { + return ffnLayerTaskGraphs; + } + + /** + * Setup all FFN layers for all transformer layers + */ + List setupFFNLayered() { + List ffnGraphs = new ArrayList<>(); + + // Initialize buffers using Phi3State directly + phi3State.temp.init(0.0f); + phi3State.tempFFN.init(0.0f); + + for (int layerIndex = 0; layerIndex < phi3Config.numberOfLayers(); layerIndex++) { + TaskGraph ffnLayer = setupSinglePhi3FFNLayer((Phi3TornadoWeights) weights, layerIndex); + if (layerIndex == phi3Config.numberOfLayers() - 1) { + setupLastID(ffnLayer.getTaskGraphName()); + } + ffnGraphs.add(ffnLayer.snapshot()); + } + + return ffnGraphs; + } + + /** + * Setup a single transformer layer for Phi3 with combined QKV and gate/up FFN + */ + TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { + + TaskGraph unifiedLayer = new TaskGraph("layer_" + layerIndex); + unifiedLayer.consumeFromDevice(phi3State.wrapX); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + // Copy-in weights per layer for batched-layered layout + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + weights.wqkvLayered[layerIndex].asHalfFloatArray(), + weights.woLayered[layerIndex].asHalfFloatArray(), + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.wUpLayered[layerIndex].asHalfFloatArray(), + weights.wDownLayered[layerIndex].asHalfFloatArray() + ); + unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); + + // RMSNorm for attention input + unifiedLayer.task("reductionsOneBlock", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, + phi3State.temp, + phi3State.wrapX, + phi3Config.dim(), + phi3Config.rmsNormEps(), + phi3State.localSize) + .task("mapContext", + TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, + context, + phi3State.wrapXb, + phi3State.wrapX, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + phi3State.temp); + + // Combined QKV projection + unifiedLayer.task("qkvmatmul", + TransformerComputeKernelsLayered::matrixVectorGeneric, + context, + phi3State.wrapXb, + phi3State.wrapQkv, + weights.wqkvLayered[layerIndex].asHalfFloatArray(), + phi3Config.dim(), + opSize, + LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("splitQKV", + TransformerComputeKernelsLayered::splitQKV, + phi3State.wrapQkv, + phi3State.wrapQ, + phi3State.wrapK, + phi3State.wrapV, + phi3Config.dim(), + phi3Config.headSize() * phi3Config.numberOfKeyValueHeads()); + + // RoPE rotation (Phi3-specific kernel) + unifiedLayer.task("rope", + TransformerComputeKernelsLayered::ropeRotationPhi3, + context, + phi3State.positionHolder, + phi3State.wrapQ, + phi3State.wrapK, + phi3Config.kvDim(), + phi3Config.headSize()); + + // Copy to caches + unifiedLayer.task("copyToCaches", + TransformerComputeKernelsLayered::copyToCache, + phi3State.wrapKeyCache, + phi3State.wrapK, + phi3State.wrapValueCache, + phi3State.wrapV, + phi3State.positionHolder, + phi3Config.kvDim(), + layerIndex, + phi3Config.contextLength()); + + // Parallel attention + unifiedLayer.task("parallel-attention", + TransformerComputeKernelsLayered::processHeadsFlashAttention, + context, + phi3State.wrapQ, + phi3State.wrapKeyCache, + phi3State.wrapValueCache, 
+ phi3State.wrapXb, + phi3Config.numberOfHeads(), + phi3Config.headSize(), + phi3Config.kvDim(), + phi3Config.kvMul(), + phi3State.positionHolder, + layerIndex, + phi3Config.contextLength()); + + // Output projection + unifiedLayer.task("matmul1", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, + context, + phi3State.wrapXb, + phi3State.wrapX, + weights.woLayered[layerIndex].asHalfFloatArray(), + phi3Config.dim(), + phi3Config.dim(), + LOCAL_WORK_GROUP_SIZE_ALLOC); + + // FFN section: RMSNorm + unifiedLayer.task("reductionsOneBlockFFN", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, + phi3State.tempFFN, + phi3State.wrapX, + phi3Config.dim(), + phi3Config.rmsNormEps(), + phi3State.localSize) + .task("mapContextFFN", + TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, + context, + phi3State.wrapXb, + phi3State.wrapX, + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + phi3State.tempFFN); + + // FFN: combined Up and Gate projection (outputs 2 * hiddenDim) + unifiedLayer.task("wGateUp", + TransformerComputeKernelsLayered::matrixVectorGeneric, + context, + phi3State.wrapXb, + phi3State.wrapHb, + weights.wUpLayered[layerIndex].asHalfFloatArray(), + phi3Config.dim(), + 2 * phi3Config.hiddenDim(), + LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("gateUpSiLU", + TransformerComputeKernelsLayered::splitGateUpAndSiLU, + phi3State.wrapHb, + phi3State.wrapHbG, + phi3State.wrapHbU, + phi3Config.hiddenDim()); + + // FFN: Down projection with residual + unifiedLayer.task("wDown", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, + context, + phi3State.wrapHbU, + phi3State.wrapX, + weights.wDownLayered[layerIndex].asHalfFloatArray(), + phi3Config.hiddenDim(), + phi3Config.dim(), + LOCAL_WORK_GROUP_SIZE_ALLOC) + .persistOnDevice( + phi3State.wrapX + ); + return unifiedLayer; + } + + /** + * Configure data transfers for first and subsequent layers + */ + protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { + // First layer: Transfer initial data to device (one-time transfer) + if (layerIndex == 0) { + // Transfer all attention-related data: query, key, value matrices and their caches + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder, state.temp, state.tempFFN); // + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // + context, state.wrapXb, state.wrapXb2, // + state.wrapQ, state.wrapK, state.wrapV, // + state.wrapKeyCache, state.wrapValueCache, // + state.wrapAtt, state.wrapHb, // + phi3State.wrapHbG, phi3State.wrapHbU, phi3State.wrapQkv); // + } else { + // Subsequent layers: Consume data already on device from previous layer + unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2, // + state.wrapQ, state.wrapK, state.wrapV, // + state.wrapKeyCache, state.wrapValueCache, // + state.wrapAtt, state.wrapHb, // + state.positionHolder, // / + phi3State.wrapHbG, phi3State.wrapHbU, phi3State.wrapQkv); + } + return unifiedLayer; + } + +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java new file mode 100644 index 00000000..858848ea --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java @@ -0,0 +1,238 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16; + +import org.beehive.gpullama3.inference.state.Qwen2State; +import 
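The opSize arithmetic is worth pinning down, since it sizes both the fused QKV matmul and the splitQKV boundaries. With assumed Phi-3-mini-4k shapes (dim 3072, 32 KV heads, headSize 96; illustrative values, not read from this diff):

    int dim = 3072, kvHeads = 32, headSize = 96;   // assumed Phi-3-mini-4k shapes
    int kvDim = kvHeads * headSize;                // 3072
    int opSize = dim + 2 * kvDim;                  // 9216: [Q | K | V] laid out contiguously in wrapQkv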
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java
new file mode 100644
index 00000000..858848ea
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java
@@ -0,0 +1,238 @@
+package org.beehive.gpullama3.tornadovm.layers.type.fp16;
+
+import org.beehive.gpullama3.inference.state.Qwen2State;
+import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
+import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
+import org.beehive.gpullama3.tornadovm.kernels.Qwen2Kernels;
+import org.beehive.gpullama3.tornadovm.kernels.Qwen3Kernels;
+import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
+import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+import uk.ac.manchester.tornado.api.TaskGraph;
+import uk.ac.manchester.tornado.api.WorkerGrid;
+import uk.ac.manchester.tornado.api.WorkerGrid1D;
+import uk.ac.manchester.tornado.api.WorkerGrid2D;
+import uk.ac.manchester.tornado.api.enums.DataTransferMode;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Qwen2FP16FFNLayers: FP16 FFN layers for Qwen2 with Group Query Attention (GQA) support.
+ *
+ * Key Differences from Qwen3:
+ * - No tempQcur/tempKcur fields in Qwen2State
+ * - Includes bias terms for Q, K, V projections
+ * - Standard GQA (no parallel offset RMSNorm)
+ * - Uses Qwen2Kernels::processHeadsFlashAttention for attention computation
+ * - Uses Qwen3Kernels::ropeRotation for position embeddings
+ * - Simpler matrix dimensions (uses config.dim() and config.kvDim() directly)
+ *
+ * Works directly with Qwen2State to access and mutate Qwen2-specific state fields.
+ */
+public class Qwen2FP16FFNLayers extends AbstractFFNLayers {
+
+    // Typed references to Qwen2-specific state and config
+    private final Qwen2State qwen2State;
+    private final Qwen2Configuration qwen2Config;
+    TaskGraph ffnLayerTaskGraph;
+    GridScheduler scheduler;
+    List<ImmutableTaskGraph> ffnLayerTaskGraphs;
+
+    public Qwen2FP16FFNLayers(String taskGraphName, Qwen2State state, Qwen2TornadoWeights weights, Qwen2Configuration config, SchedulerType schedulerType) {
+        super(taskGraphName, state, weights, config, schedulerType);
+        this.qwen2State = state;
+        this.qwen2Config = config;
+        ffnLayerTaskGraphs = setupFFNLayered();
+    }
+
+    @Override
+    public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
+        int h = config.numberOfHeads();
+        int ic = config.headSize() / 2;
+        WorkerGrid ropeWorker = new WorkerGrid2D(h, ic);
+        ropeWorker.setGlobalWork(h, ic, 1);
+        ropeWorker.setLocalWork(1, 1, 1);
+
+        int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
+        WorkerGrid configDimRowMajorGlobalWorker = new WorkerGrid1D(configDimRowMajorGlobal);
+        configDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);
+
+        int configKvDimRowMajorGlobal = config.kvDim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
+        WorkerGrid configKvDimRowMajorGlobalWorker = new WorkerGrid1D(configKvDimRowMajorGlobal);
+        configKvDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);
+
+        WorkerGrid qBiasWorker = new WorkerGrid1D(config.dim());
+        qBiasWorker.setGlobalWork(config.dim(), 1, 1);
+        qBiasWorker.setLocalWork(config.dim() / 8, 1, 1);
+        WorkerGrid kvBiasWorker = new WorkerGrid1D(config.kvDim());
+        kvBiasWorker.setGlobalWork(config.kvDim(), 1, 1);
+        kvBiasWorker.setLocalWork(32, 1, 1);
+
+        int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
+        WorkerGrid configHiddenDimRowMajorWorker = new WorkerGrid1D(configHiddenDimRowMajor);
+        configHiddenDimRowMajorWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);
+
+        WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 32);
+
+        // Parallel attention worker configuration
+        // Calculate optimal local work size based on head dimension
+        int optimalLocalSize = Math.min(config.headSize(), 64); // Start with 64 threads per head
+        if (config.headSize() % optimalLocalSize != 0) {
+            // Find largest divisor of headSize <= 64
+            for (int size = 64; size >= 1; size--) {
+                if (config.headSize() % size == 0) {
+                    optimalLocalSize = size;
+                    break;
+                }
+            }
+        }
+
+        WorkerGrid parallelAttentionWorker = new WorkerGrid1D(config.numberOfHeads());
+        parallelAttentionWorker.setGlobalWork(config.numberOfHeads() * optimalLocalSize, 1, 1);
+        parallelAttentionWorker.setLocalWork(optimalLocalSize, 1, 1);
+
+        WorkerGrid copyToCachesWorker = new WorkerGrid1D(config.kvDim());
+        copyToCachesWorker.setGlobalWork(config.kvDim(), 1, 1);
+        copyToCachesWorker.setLocalWork(32, 1, 1); // Set local work size to 32 (for copying to caches)
+
+        // Map workers to tasks
+        for (int i = 0; i < config.numberOfLayers(); i++) {
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", configDimRowMajorGlobalWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", configKvDimRowMajorGlobalWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", configKvDimRowMajorGlobalWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qbias", qBiasWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kbias", kvBiasWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vbias", kvBiasWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker);
+        }
+        return tornadoForwardScheduler;
+    }
+
+    @Override
+    public GridScheduler getGridScheduler() {
+        return scheduler;
+    }
+
+    @Override
+    public TaskGraph getTaskGraph() {
+        return ffnLayerTaskGraph;
+    }
+
+    @Override
+    public ImmutableTaskGraph getImmutableTaskGraph() {
+        return null;
+    }
+
+    public List<ImmutableTaskGraph> getFfnLayerTaskGraphs() {
+        return ffnLayerTaskGraphs;
+    }
+
+    /**
+     * Setup all FFN layers for all transformer layers
+     */
+    List<ImmutableTaskGraph> setupFFNLayered() {
+        List<ImmutableTaskGraph> ffnGraphs = new ArrayList<>();
+
+        qwen2State.temp.init(0.0f);
+        qwen2State.tempFFN.init(0.0f);
+
+        for (int layerIndex = 0; layerIndex < qwen2Config.numberOfLayers(); layerIndex++) {
+            TaskGraph ffnLayer = setupSingleQwen2FFNLayer((Qwen2TornadoWeights) weights, layerIndex);
+            if (layerIndex == qwen2Config.numberOfLayers() - 1) {
+                setupLastID(ffnLayer.getTaskGraphName());
+            }
+            ffnGraphs.add(ffnLayer.snapshot());
+        }
+
+        return ffnGraphs;
+    }
setupSingleQwen2FFNLayer(Qwen2TornadoWeights weights, int layerIndex) { + var taskGraphName = "layer_" + layerIndex; + TaskGraph unifiedLayer = new TaskGraph(taskGraphName); + unifiedLayer.consumeFromDevice(state.wrapX); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // + weights.rms_att_weightLayered[layerIndex].asFloatArray(), // + weights.wqLayered[layerIndex].asHalfFloatArray(), // + weights.wkLayered[layerIndex].asHalfFloatArray(), // + weights.wvLayered[layerIndex].asHalfFloatArray(), // + weights.woLayered[layerIndex].asHalfFloatArray(), // + weights.q_biasLayered[layerIndex].asFloatArray(), // + weights.k_biasLayered[layerIndex].asFloatArray(), // + weights.v_biasLayered[layerIndex].asFloatArray(), // + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // + weights.w1Layered[layerIndex].asHalfFloatArray(), // + weights.w2Layered[layerIndex].asHalfFloatArray(), // + weights.w3Layered[layerIndex].asHalfFloatArray()); // + unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); // + + unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen2State.temp, qwen2State.wrapX, config.dim(), config.rmsNormEps(), + qwen2State.localSize) + .task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, qwen2State.wrapXb, qwen2State.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), + qwen2State.temp) + .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen2State.wrapXb, qwen2State.wrapQ, weights.wqLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(), + LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen2State.wrapXb, qwen2State.wrapK, weights.wkLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(), + LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen2State.wrapXb, qwen2State.wrapV, weights.wvLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(), + LOCAL_WORK_GROUP_SIZE_ALLOC).task("qbias", TransformerComputeKernelsLayered::addInPlace, qwen2State.wrapQ, weights.q_biasLayered[layerIndex].asFloatArray(), config.dim()) + .task("kbias", TransformerComputeKernelsLayered::addInPlace, qwen2State.wrapK, weights.k_biasLayered[layerIndex].asFloatArray(), config.kvDim()) + .task("vbias", TransformerComputeKernelsLayered::addInPlace, qwen2State.wrapV, weights.v_biasLayered[layerIndex].asFloatArray(), config.kvDim()) + .task("rope", Qwen3Kernels::ropeRotation, context, qwen2State.positionHolder, qwen2State.wrapQ, qwen2State.wrapK, config.numberOfKeyValueHeads(), config.headSize()) + .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, qwen2State.wrapKeyCache, qwen2State.wrapK, qwen2State.wrapValueCache, qwen2State.wrapV, qwen2State.positionHolder, + config.kvDim(), layerIndex, config.contextLength()) + .task("parallel-attention", Qwen2Kernels::processHeadsFlashAttention, context, qwen2State.wrapQ, qwen2State.wrapKeyCache, qwen2State.wrapValueCache, qwen2State.wrapXb, + config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), qwen2State.positionHolder, layerIndex, config.contextLength()) + .task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, qwen2State.wrapXb, qwen2State.wrapX, weights.woLayered[layerIndex].asHalfFloatArray(), config.dim(), + config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + 
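+                // The remaining tasks form the FFN half of the layer. As a sketch of the
+                // math only (not the kernel signatures): after the FFN RMSNorm pair,
+                //   hb = silu(w1 * xb) .* (w3 * xb)   // fused in fused_ffn_w1_w3
+                //   x  = x + w2 * hb                  // projectionTwo adds the residual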
.task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen2State.tempFFN, qwen2State.wrapX, config.dim(), config.rmsNormEps(), + qwen2State.localSize) + .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, qwen2State.wrapXb, qwen2State.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + qwen2State.tempFFN) + .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, qwen2State.wrapXb, qwen2State.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(), + weights.w3Layered[layerIndex].asHalfFloatArray(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, qwen2State.wrapHb, qwen2State.wrapX, weights.w2Layered[layerIndex].asHalfFloatArray(), + config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(state.wrapX); + + return unifiedLayer; + } + + /** + * Configure data transfers for first and subsequent layers + */ + protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { + // First layer: Transfer initial data to device (one-time transfer) + if (layerIndex == 0) { + // Transfer all attention-related data: query, key, value matrices and their caches + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, qwen2State.positionHolder, qwen2State.temp, qwen2State.tempFFN); // + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // + context, qwen2State.wrapXb, qwen2State.wrapXb2, // + qwen2State.wrapQ, qwen2State.wrapK, qwen2State.wrapV, // + qwen2State.wrapKeyCache, qwen2State.wrapValueCache, // + qwen2State.wrapAtt, qwen2State.wrapHb); // + } else { + // Subsequent layers: Consume data already on device from previous layer + unifiedLayer.consumeFromDevice( // + context, qwen2State.wrapXb, qwen2State.wrapXb2, // + qwen2State.wrapQ, qwen2State.wrapK, qwen2State.wrapV, // + qwen2State.wrapKeyCache, qwen2State.wrapValueCache, // + qwen2State.wrapAtt, qwen2State.wrapHb, // + qwen2State.positionHolder // + ); + } + return unifiedLayer; + } + +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java new file mode 100644 index 00000000..379921c3 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java @@ -0,0 +1,290 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16; + +import org.beehive.gpullama3.inference.state.Qwen3State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights; +import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; +import org.beehive.gpullama3.tornadovm.kernels.Qwen3Kernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +import java.util.ArrayList; +import java.util.List; + +/** + * Qwen3FP16FFNLayers: FP16 FFN 
layers for Qwen3 with Group Query Attention (GQA) support.
+ *
+ * Key Differences from Llama:
+ * - Supports GQA with separate KV heads (nHeadKv)
+ * - Uses Qwen3Kernels for RMSNorm with parallel offset
+ * - Custom RoPE rotation for Qwen3
+ * - Different attention computation due to the GQA structure
+ *
+ * Works directly with Qwen3State to access and mutate Qwen3-specific state fields like tempQcur and tempKcur.
+ */
+public class Qwen3FP16FFNLayers extends AbstractFFNLayers {
+
+    // Typed references to Qwen3-specific state and config
+    private final Qwen3State qwen3State;
+    private final Qwen3Configuration qwen3Config;
+    // Qwen3-specific GQA parameters
+    private final int nHeadKv;
+    private final int nEmbdHeadK;
+    private final int nEmbdHeadV;
+    private final int nEmbdVGqa;
+    private final int nEmbdHead;
+    private final int nEmbdGqa;
+    private final int gqa;
+    TaskGraph ffnLayerTaskGraph;
+    GridScheduler scheduler;
+    List<ImmutableTaskGraph> ffnLayerTaskGraphs;
+
+    public Qwen3FP16FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWeights weights, Qwen3Configuration config, SchedulerType schedulerType) {
+        super(taskGraphName, state, weights, config, schedulerType);
+        this.qwen3State = state;
+        this.qwen3Config = config;
+
+        // Initialize GQA parameters from Qwen3Config
+        this.nHeadKv = config.numberOfKeyValueHeads();
+        this.nEmbdHeadK = config.numberOfHeadsKey();
+        this.nEmbdHeadV = config.numberOfHeadsValue();
+        this.nEmbdVGqa = nEmbdHeadV * nHeadKv;
+        this.nEmbdHead = nEmbdHeadV;
+        this.nEmbdGqa = nEmbdVGqa;
+        this.gqa = config.numberOfHeads() / config.numberOfKeyValueHeads();
+        ffnLayerTaskGraphs = setupFFNLayered();
+    }
+
+    @Override
+    public GridScheduler updateGridScheduler(GridScheduler gridScheduler) {
+        WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), state.localSize);
+
+        // Q matmul worker (GQA: full query heads)
+        int matmulQGlobal = nEmbdHeadK * config.numberOfHeads() * LOCAL_WORK_GROUP_SIZE_ALLOC;
+        WorkerGrid matmulQRowMajorWorker = WorkerGridFactory.genericWorker(matmulQGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+        // KV matmul worker (GQA: reduced KV heads)
+        int matmulKVGlobal = nEmbdGqa * LOCAL_WORK_GROUP_SIZE_ALLOC;
+        WorkerGrid matmulKVRowMajorWorker = WorkerGridFactory.genericWorker(matmulKVGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+        // Current embedding head worker
+        WorkerGrid curWorker = WorkerGridFactory.createRmsNormWorker(nEmbdHead, 128);
+
+        // Q current worker
+        WorkerGrid qCurWorker = WorkerGridFactory.genericWorker(config.numberOfHeads() * nEmbdHead, nEmbdHead);
+
+        // K current worker
+        WorkerGrid kCurWorker = WorkerGridFactory.genericWorker(config.numberOfKeyValueHeads() * nEmbdHead, nEmbdHead);
+
+        // RoPE worker (2D: heads x embedding_head/2)
+        int ic = nEmbdHead / 2;
+        WorkerGrid ropeWorker = WorkerGridFactory.createRoPEWorker(config.numberOfHeads(), nEmbdHead);
+
+        // Copy to cache worker
+        WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(nEmbdGqa, 128);
+
+        // Parallel attention worker
+        WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), nEmbdHead);
+
+        // Matmul1 worker (output projection)
+        int matmul1Global = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
+        WorkerGrid matmul1Worker = WorkerGridFactory.genericWorker(matmul1Global, LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+        // FFN workers
+        int fusedFFNW1W3Global = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
+        WorkerGrid fusedFFNW1W3Worker = WorkerGridFactory.genericWorker(fusedFFNW1W3Global, LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+        int
projectionTwoGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid projectionTwoWorker = WorkerGridFactory.genericWorker(projectionTwoGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + // Map workers to tasks for each layer + for (int i = 0; i < config.numberOfLayers(); i++) { + gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); + + gridScheduler.addWorkerGrid("layer_" + i + ".qmatmul", matmulQRowMajorWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".kmatmul", matmulKVRowMajorWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".vmatmul", matmulKVRowMajorWorker); + + gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormReduction_Qcur", qCurWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormMapIndexInPlace_Qcur", qCurWorker); + + gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormReduction_Kcur", kCurWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormMapIndexInPlace_Kcur", kCurWorker); + + gridScheduler.addWorkerGrid("layer_" + i + ".ropeRotation", ropeWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".matmul1", matmul1Worker); + + gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); + gridScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", fusedFFNW1W3Worker); + gridScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", projectionTwoWorker); + } + + return gridScheduler; + } + + @Override + public GridScheduler getGridScheduler() { + return scheduler; + } + + @Override + public TaskGraph getTaskGraph() { + return ffnLayerTaskGraph; + } + + @Override + public ImmutableTaskGraph getImmutableTaskGraph() { + return null; + } + + public List getFfnLayerTaskGraphs() { + return ffnLayerTaskGraphs; + } + + /** + * Setup all FFN layers for all transformer layers + */ + List setupFFNLayered() { + List ffnGraphs = new ArrayList<>(); + qwen3State.temp.init(0.0f); + qwen3State.tempFFN.init(0.0f); + qwen3State.tempQcur.init(0.0f); + qwen3State.tempKcur.init(0.0f); + + for (int layerIndex = 0; layerIndex < qwen3Config.numberOfLayers(); layerIndex++) { + TaskGraph ffnLayer = setupSingleQwen3FFNLayer((Qwen3TornadoWeights) weights, layerIndex); + if (layerIndex == qwen3Config.numberOfLayers() - 1) { + setupLastID(ffnLayer.getTaskGraphName()); + } + ffnGraphs.add(ffnLayer.snapshot()); + } + + return ffnGraphs; + } + + /** + * Setup a single transformer layer for Qwen3 with GQA + */ + TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) { + var taskGraphName = "layer_" + layerIndex; + TaskGraph unifiedLayer = new TaskGraph(taskGraphName); + unifiedLayer.consumeFromDevice(state.wrapX); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // + //Copy-in weights per layer for batched-layered layout + weights.rms_att_weightLayered[layerIndex].asFloatArray(), // + weights.wqLayered[layerIndex].asHalfFloatArray(), // + weights.wkLayered[layerIndex].asHalfFloatArray(), // + weights.wvLayered[layerIndex].asHalfFloatArray(), // + weights.woLayered[layerIndex].asHalfFloatArray(), // + //rms_att_KNormLayered + weights.rms_att_KNormLayered[layerIndex].asFloatArray(), // + //rms_att_QNormLayered + 
weights.rms_att_QNormLayered[layerIndex].asFloatArray(), // + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // + weights.w1Layered[layerIndex].asHalfFloatArray(), // + weights.w2Layered[layerIndex].asHalfFloatArray(), // + weights.w3Layered[layerIndex].asHalfFloatArray() // + ); + unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); + unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen3State.temp, qwen3State.wrapX, // in + qwen3Config.dim(), qwen3Config.rmsNormEps(), qwen3State.localSize).task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, qwen3State.wrapXb, // out + qwen3State.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), qwen3State.temp); + + int qDim0 = nEmbdHeadK * qwen3Config.numberOfHeads(); + int kvDim0 = nEmbdGqa; + int qkvDim1 = qwen3Config.dim(); + unifiedLayer.task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen3State.wrapXb, qwen3State.wrapQ, // output + weights.wqLayered[layerIndex].asHalfFloatArray(), qkvDim1, qDim0, LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen3State.wrapXb, qwen3State.wrapK, // output + weights.wkLayered[layerIndex].asHalfFloatArray(), qkvDim1, kvDim0, LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, qwen3State.wrapXb, qwen3State.wrapV, // output + weights.wvLayered[layerIndex].asHalfFloatArray(), qkvDim1, kvDim0, LOCAL_WORK_GROUP_SIZE_ALLOC); + + // Qcur rmsnorm + unifiedLayer.task("rmsnormReduction_Qcur", Qwen3Kernels::rmsnormWithParallelOffset, context, qwen3State.tempQcur, // output + qwen3State.wrapQ, // input + qwen3State.localSize, // currently 128, should be variable of global nEmbHead + nEmbdHead, // for normalization + qwen3Config.rmsNormEps()) // for normalization + .task("rmsnormMapIndexInPlace_Qcur", Qwen3Kernels::rmsnormMapIndexInPlaceWithParallelOffset, context, qwen3State.wrapQ, // output + weights.rms_att_QNormLayered[layerIndex].asFloatArray(), nEmbdHead, qwen3State.tempQcur); + + // Kcur rmsnorm + unifiedLayer.task("rmsnormReduction_Kcur", Qwen3Kernels::rmsnormWithParallelOffset, context, qwen3State.tempKcur, // output + qwen3State.wrapK, // input + qwen3State.localSize, // currently 128, should be variable of global nEmbHead + nEmbdHead, // for normalization + qwen3Config.rmsNormEps()) // for normalization + .task("rmsnormMapIndexInPlace_Kcur", Qwen3Kernels::rmsnormMapIndexInPlaceWithParallelOffset, context, qwen3State.wrapK, // output + weights.rms_att_KNormLayered[layerIndex].asFloatArray(), nEmbdHead, qwen3State.tempKcur); + + // rope rotation task graph + unifiedLayer.task("ropeRotation", Qwen3Kernels::ropeRotation, context, qwen3State.positionHolder, qwen3State.wrapQ, // out + qwen3State.wrapK, // out + qwen3Config.numberOfKeyValueHeads(), nEmbdHead); + + unifiedLayer.task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, qwen3State.wrapKeyCache, // out + qwen3State.wrapK, // in + qwen3State.wrapValueCache, // out + qwen3State.wrapV, // in + qwen3State.positionHolder, nEmbdGqa, layerIndex, qwen3Config.contextLength()); + + unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttentionOpt, context, qwen3State.wrapQ, qwen3State.wrapKeyCache, qwen3State.wrapValueCache, + qwen3State.wrapXb, // out + qwen3Config.numberOfHeads(), nEmbdHead, nEmbdGqa, gqa, 
qwen3State.positionHolder, layerIndex, qwen3Config.contextLength()); + + unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, qwen3State.wrapXb, // vector + qwen3State.wrapX, // out, should be [1024] + weights.woLayered[layerIndex].asHalfFloatArray(), // matrix + nEmbdHeadK * qwen3Config.numberOfHeads(), // dim1 = 2048 + qwen3Config.dim(), // dim0 = 1024 + LOCAL_WORK_GROUP_SIZE_ALLOC); + + unifiedLayer.task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, qwen3State.tempFFN, qwen3State.wrapX, qwen3Config.dim(), + qwen3Config.rmsNormEps(), qwen3State.localSize) + .task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, qwen3State.tempFFN, qwen3Config.dim(), qwen3Config.rmsNormEps()) + .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, qwen3State.wrapXb, qwen3State.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + qwen3State.tempFFN); + + unifiedLayer.task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, qwen3State.wrapXb, qwen3State.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(), + weights.w3Layered[layerIndex].asHalfFloatArray(), qwen3Config.dim(), qwen3Config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, qwen3State.wrapHb, qwen3State.wrapX, weights.w2Layered[layerIndex].asHalfFloatArray(), + qwen3Config.hiddenDim(), qwen3Config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(qwen3State.wrapX); + return unifiedLayer; + } + + /** + * Configure data transfers for first and subsequent layers + */ + protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { + if (layerIndex == 0) { + // First layer: Transfer temporary buffers and QKV state every execution + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, qwen3State.positionHolder, qwen3State.temp, qwen3State.tempFFN); + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, qwen3State.tempQcur, qwen3State.tempKcur); + // First execution: allocate workspace buffers + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // + context, qwen3State.wrapXb, qwen3State.wrapXb2, // + qwen3State.wrapQ, qwen3State.wrapK, qwen3State.wrapV, // + qwen3State.wrapKeyCache, qwen3State.wrapValueCache, // + qwen3State.wrapAtt, qwen3State.wrapHb); + } else { + // Subsequent layers: Consume data from previous layer + unifiedLayer.consumeFromDevice(context, qwen3State.wrapXb, qwen3State.wrapXb2, // + qwen3State.wrapQ, qwen3State.wrapK, // + qwen3State.wrapV, qwen3State.wrapKeyCache, // + qwen3State.wrapValueCache, qwen3State.wrapAtt, // + qwen3State.wrapHb, qwen3State.positionHolder); // + + unifiedLayer.consumeFromDevice(qwen3State.tempQcur, qwen3State.tempKcur); + } + return unifiedLayer; + } + +} \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java new file mode 100644 index 00000000..a2d16830 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java @@ -0,0 +1,174 @@ +package org.beehive.gpullama3.tornadovm.layers.type.q8_0; + +import org.beehive.gpullama3.inference.state.LlamaState; +import 
org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +import java.util.List; +import java.util.stream.IntStream; + +public class LlamaQ8_0FFNLayers extends AbstractFFNLayers { + + GridScheduler scheduler; + List ffnLayerTaskGraphs; + + public LlamaQ8_0FFNLayers(String taskGraphName, LlamaState state, LlamaTornadoWeights weights, Configuration config, SchedulerType schedulerType) { + super(taskGraphName, state, weights, config, schedulerType); + ffnLayerTaskGraphs = setupFFNLayered(); + } + + @Override + public GridScheduler getGridScheduler() { + return scheduler; + } + + @Override + public TaskGraph getTaskGraph() { + return null; + } + + @Override + public ImmutableTaskGraph getImmutableTaskGraph() { + return null; + } + + List setupFFNLayered() { + state.temp.init(0.0f); + state.tempFFN.init(0.0f); + var numLayers = config.numberOfLayers(); + + return IntStream.range(0, numLayers).mapToObj(i -> { + var ffnLayer = setupSingleFFNLayer((LlamaTornadoWeights) weights, config, i); + if (i == numLayers - 1) { + setupLastID(ffnLayer.getTaskGraphName()); + } + return ffnLayer.snapshot(); + }).toList(); + } + + TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config, int layerIndex) { + var layerTaskGraphName = "layer_" + layerIndex; + TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); + unifiedLayer.consumeFromDevice(state.wrapX); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + //Copy-in weights per layer for batched-layered layout + weights.rms_att_weightLayered[layerIndex].asFloatArray(), weights.wqLayered[layerIndex].getQuants(), weights.wqLayered[layerIndex].getScales(), weights.wkLayered[layerIndex].getQuants(), + weights.wkLayered[layerIndex].getScales(), weights.wvLayered[layerIndex].getQuants(), weights.wvLayered[layerIndex].getScales(), weights.woLayered[layerIndex].getQuants(), + weights.woLayered[layerIndex].getScales(), weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), weights.w1Layered[layerIndex].getQuants(), weights.w1Layered[layerIndex].getScales(), + weights.w2Layered[layerIndex].getQuants(), weights.w2Layered[layerIndex].getScales(), weights.w3Layered[layerIndex].getQuants(), weights.w3Layered[layerIndex].getScales()); + unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); + unifiedLayer.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); + if (shouldUseFinalNormalization()) { + unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp, + config.dim(), config.rmsNormEps()); + } + unifiedLayer.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp) + 
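+                // The quantized matmuls below read each Q8_0 tensor as two arrays
+                // (getQuants()/getScales()); a sketch of the per-element dequantization
+                // the kernels are assumed to apply: w[i] = quants[i] * scales[i / 32],
+                // with one scale per 32-element block of int8 quants.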
.task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex].getQuants(), + weights.wqLayered[layerIndex].getScales(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapK, weights.wkLayered[layerIndex].getQuants(), + weights.wkLayered[layerIndex].getScales(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapV, weights.wvLayered[layerIndex].getQuants(), + weights.wvLayered[layerIndex].getScales(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("rope", TransformerComputeKernelsLayered::ropeRotation, context, state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(), config.headSize()) + .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(), + layerIndex, config.contextLength()); + configureAttention(unifiedLayer, layerIndex); + unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapXb, state.wrapX, weights.woLayered[layerIndex].getQuants(), + weights.woLayered[layerIndex].getScales(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); + if (shouldUseFinalNormalization()) { + unifiedLayer.task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN, + config.dim(), config.rmsNormEps()); + } + unifiedLayer.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN) + .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex].getQuants(), + weights.w1Layered[layerIndex].getScales(), weights.w3Layered[layerIndex].getQuants(), weights.w3Layered[layerIndex].getScales(), config.dim(), config.hiddenDim(), + LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapHb, state.wrapX, weights.w2Layered[layerIndex].getQuants(), + weights.w2Layered[layerIndex].getScales(), config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(state.wrapX); + return unifiedLayer; + } + + protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { + // First layer: Transfer initial data to device (one-time transfer) + if (layerIndex == 0) { + // Transfer all attention-related data: query, key, value matrices and their caches + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder, state.temp, state.tempFFN); // + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // + context, state.wrapXb, state.wrapXb2, // + state.wrapQ, state.wrapK, state.wrapV, // + state.wrapKeyCache, state.wrapValueCache, // + state.wrapAtt, state.wrapHb); // + } else { + // Subsequent layers: Consume data already on device from previous layer + 
unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2, // + state.wrapQ, state.wrapK, state.wrapV, // + state.wrapKeyCache, state.wrapValueCache, // + state.wrapAtt, state.wrapHb, // + state.positionHolder // + ); + } + return unifiedLayer; + } + + @Override + public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { + WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 128); + WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256); + + int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + int configKvDimRowMajorGlobal = config.kvDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configKvDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configKvDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configHiddenDimRowMajorWorker = WorkerGridFactory.genericWorker(configHiddenDimRowMajor, LOCAL_WORK_GROUP_SIZE_ALLOC); + WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize()); + WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(config.dim(), 128); + + for (int i = 0; i < config.numberOfLayers(); i++) { + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", configDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", configKvDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", configKvDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); + } + return tornadoForwardScheduler; + } + + public List getFfnLayerTaskGraphs() { + return ffnLayerTaskGraphs; + } + + private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex) { + if (schedulerType == SchedulerType.NVIDIA) { + return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttention, + context, state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, + config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), + state.positionHolder, layerIndex, config.contextLength()); + } else { + return unifiedLayer.task("parallel-attention", TransformerComputeKernelsLayered::processHeadsParallel, + state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, + config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), config.contextLength(), + 
state.positionHolder, state.wrapAtt, layerIndex, config.contextLength()); + } + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java new file mode 100644 index 00000000..75f81d92 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsQ8_0Layer.java @@ -0,0 +1,90 @@ +package org.beehive.gpullama3.tornadovm.layers.type.q8_0; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; +import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.AbstractLayer; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.WorkerGrid1D; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +import java.util.SequencedCollection; + +public class LogitsQ8_0Layer extends AbstractLayer { + + private String lastTaskGraphID; + private TaskGraph logitsTaskGraph; + private ImmutableTaskGraph immutableLogitsGraph; + private GridScheduler scheduler; + private SchedulerType schedulerType; + + public LogitsQ8_0Layer(String taskGraphName, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) { + super(taskGraphName, state, weights, config); + this.lastTaskGraphID = lastTaskGraphID; + state.tempLogits.init(0.0f); + var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsQ8_0Layer", "TornadoTensor"); + this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, config); + this.schedulerType = schedulerType; + } + + @Override + public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { + WorkerGrid logitsRMS; + if (weights instanceof Qwen2TornadoWeights) { + logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 32); + } else { + logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), 256); + } + + var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; + WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); + vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); + + tornadoForwardScheduler.addWorkerGrid("logits.projection", vocabWorker); + tornadoForwardScheduler.addWorkerGrid("logits.reductionsOneBlockLogits", logitsRMS); + tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", logitsRMS); + return tornadoForwardScheduler; + } + + private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) { + TaskGraph logits = new TaskGraph("logits"); + logits.consumeFromDevice(lastTaskGraphID, state.wrapX).transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits) + .transferToDevice(DataTransferMode.FIRST_EXECUTION, context, state.wrapLogits, weights.wclsByteArray.getQuants(), 
weights.wclsByteArray.getScales(), + weights.rms_final_weight_as_floatArray) + .task("reductionsOneBlockLogits", TransformerComputeKernels::reductionOneBlockWithLayer, context, state.tempLogits, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize); + if (schedulerType == SchedulerType.NON_NVIDIA) { + logits.task("reductionFinalNormalizationLogits", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempLogits, config.dim(), config.rmsNormEps()); + } + logits.task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX, weights.rms_final_weight_as_floatArray.asFloatArray(), state.tempLogits) + .task("projection", TransformerComputeKernelsLayered::matrixVectorGeneric, // + context, state.wrapX, state.wrapLogits, weights.wclsByteArray.getQuants(), weights.wclsByteArray.getScales(), // + config.dim(), config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS) // + .transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); + return logits; + } + + @Override + public GridScheduler getGridScheduler() { + return scheduler; + } + + @Override + public TaskGraph getTaskGraph() { + return logitsTaskGraph; + } + + @Override + public ImmutableTaskGraph getImmutableTaskGraph() { + return immutableLogitsGraph; + } + +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java new file mode 100644 index 00000000..d4328a1d --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java @@ -0,0 +1,317 @@ +package org.beehive.gpullama3.tornadovm.layers.type.q8_0; + +import org.beehive.gpullama3.inference.state.Phi3State; +import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights; +import org.beehive.gpullama3.model.phi3.Phi3Configuration; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +import java.util.ArrayList; +import java.util.List; + +/** + * Phi3Q8_0FFNLayers: Q8_0-quantized FFN layers for Phi3 with Group Query Attention (GQA) support. + * + * Key Differences from Phi3FP16FFNLayers: + * - Uses Q8_0-quantized weights (getQuants() and getScales()) + * - Same attention and RoPE kernels as FP16 version + * - 8-bit integer computations with dequantization + * - 2x memory compression vs FP16 + * - Same combined QKV and gate/up FFN structure + * + * Works directly with Phi3State to access and mutate Phi3-specific state fields. 
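+ *
+ * A minimal sketch of the Q8_0 layout assumed here: each weight tensor ships as
+ * int8 quants plus one scale per 32-element block, so dequantization is (the
+ * helper below is illustrative only, not part of this class):
+ *
+ * <pre>{@code
+ * static float dequantQ8_0(byte[] quants, float[] scales, int i) {
+ *     return quants[i] * scales[i / 32]; // block size 32
+ * }
+ * }</pre>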
+ */ +public class Phi3Q8_0FFNLayers extends AbstractFFNLayers { + + TaskGraph ffnLayerTaskGraph; + GridScheduler scheduler; + List ffnLayerTaskGraphs; + + // Typed references to Phi3-specific state and config + private final Phi3State phi3State; + private final Phi3Configuration phi3Config; + + // Phi3-specific dimension for combined QKV buffer + private final int opSize; + + public Phi3Q8_0FFNLayers(String taskGraphName, Phi3State state, Phi3TornadoWeights weights, Phi3Configuration config, SchedulerType schedulerType) { + super(taskGraphName, state, weights, config, schedulerType); + this.phi3State = state; + this.phi3Config = config; + this.opSize = config.dim() + 2 * (config.numberOfKeyValueHeads() * config.headSize()); + ffnLayerTaskGraphs = setupFFNLayered(); + } + + @Override + public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { + WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256); + WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 128); + + int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + final int opSize = config.dim() + 2 * (config.numberOfKeyValueHeads() * config.headSize()); + + int qkvmatmulDimRowMajorGlobal = opSize * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid qkvDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(qkvmatmulDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + int wgetUPDimRowMajor = 2 * config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid wgetHiddenDimRowMajorWorker = WorkerGridFactory.genericWorker(wgetUPDimRowMajor, LOCAL_WORK_GROUP_SIZE_ALLOC); + + WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize()); + WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(config.dim(), 128); + WorkerGrid splitGateUpSiLUWorker = WorkerGridFactory.genericWorker(config.hiddenDim(), 128); + WorkerGrid splitQKVWorker = WorkerGridFactory.genericWorker(opSize, 128); + for (int i = 0; i < config.numberOfLayers(); i++) { + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkvmatmul", qkvDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".splitQKV", splitQKVWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".wDown", configDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".wGateUp", wgetHiddenDimRowMajorWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".gateUpSiLU", splitGateUpSiLUWorker); + } + return tornadoForwardScheduler; + } + + @Override + public GridScheduler getGridScheduler() { + return scheduler; + } + + @Override + 
public TaskGraph getTaskGraph() { + return ffnLayerTaskGraph; + } + + @Override + public ImmutableTaskGraph getImmutableTaskGraph() { + return null; + } + + public List getFfnLayerTaskGraphs() { + return ffnLayerTaskGraphs; + } + + /** + * Setup all FFN layers for all transformer layers + */ + List setupFFNLayered() { + List ffnGraphs = new ArrayList<>(); + + // Initialize buffers using Phi3State directly + phi3State.temp.init(0.0f); + phi3State.tempFFN.init(0.0f); + + for (int layerIndex = 0; layerIndex < phi3Config.numberOfLayers(); layerIndex++) { + TaskGraph ffnLayer = setupSinglePhi3Q8_0FFNLayer((Phi3TornadoWeights) weights, layerIndex); + if (layerIndex == phi3Config.numberOfLayers() - 1) { + setupLastID(ffnLayer.getTaskGraphName()); + } + ffnGraphs.add(ffnLayer.snapshot()); + } + + return ffnGraphs; + } + + /** + * Setup a single transformer layer for Phi3 with Q8_0 quantization, combined QKV and gate/up FFN + */ + TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex) { + + TaskGraph unifiedLayer = new TaskGraph("layer_" + layerIndex); + unifiedLayer.consumeFromDevice(phi3State.wrapX); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + // Copy-in quantized weights per layer + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + weights.wqkvLayered[layerIndex].getQuants(), + weights.wqkvLayered[layerIndex].getScales(), + weights.woLayered[layerIndex].getQuants(), + weights.woLayered[layerIndex].getScales(), + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.wUpLayered[layerIndex].getQuants(), + weights.wUpLayered[layerIndex].getScales(), + weights.wDownLayered[layerIndex].getQuants(), + weights.wDownLayered[layerIndex].getScales() + ); + unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); + + // RMSNorm for attention input + unifiedLayer.task("reductionsOneBlock", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, + phi3State.temp, + phi3State.wrapX, + phi3Config.dim(), + phi3Config.rmsNormEps(), + phi3State.localSize) + .task("mapContext", + TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, + context, + phi3State.wrapXb, + phi3State.wrapX, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + phi3State.temp); + + // Combined QKV projection (quantized) + unifiedLayer.task("qkvmatmul", + TransformerComputeKernelsLayered::matrixVectorGeneric, + context, + phi3State.wrapXb, + phi3State.wrapQkv, + weights.wqkvLayered[layerIndex].getQuants(), + weights.wqkvLayered[layerIndex].getScales(), + phi3Config.dim(), + opSize, + LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("splitQKV", + TransformerComputeKernelsLayered::splitQKV, + phi3State.wrapQkv, + phi3State.wrapQ, + phi3State.wrapK, + phi3State.wrapV, + phi3Config.dim(), + phi3Config.headSize() * phi3Config.numberOfKeyValueHeads()); + + // RoPE rotation (Phi3-specific kernel) + unifiedLayer.task("rope", + TransformerComputeKernelsLayered::ropeRotationPhi3, + context, + phi3State.positionHolder, + phi3State.wrapQ, + phi3State.wrapK, + phi3Config.kvDim(), + phi3Config.headSize()); + + // Copy to caches + unifiedLayer.task("copyToCaches", + TransformerComputeKernelsLayered::copyToCache, + phi3State.wrapKeyCache, + phi3State.wrapK, + phi3State.wrapValueCache, + phi3State.wrapV, + phi3State.positionHolder, + phi3Config.kvDim(), + layerIndex, + phi3Config.contextLength()); + + // Parallel attention + unifiedLayer.task("parallel-attention", + TransformerComputeKernelsLayered::processHeadsFlashAttention, + context, + 
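+                        // GQA: kvMul (= numberOfHeads / numberOfKeyValueHeads) query heads
+                        // share each cached KV head during this flash-attention pass.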
phi3State.wrapQ, + phi3State.wrapKeyCache, + phi3State.wrapValueCache, + phi3State.wrapXb, + phi3Config.numberOfHeads(), + phi3Config.headSize(), + phi3Config.kvDim(), + phi3Config.kvMul(), + phi3State.positionHolder, + layerIndex, + phi3Config.contextLength()); + + // Output projection (quantized) + unifiedLayer.task("matmul1", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, + context, + phi3State.wrapXb, + phi3State.wrapX, + weights.woLayered[layerIndex].getQuants(), + weights.woLayered[layerIndex].getScales(), + phi3Config.dim(), + phi3Config.dim(), + LOCAL_WORK_GROUP_SIZE_ALLOC); + + // FFN section: RMSNorm + unifiedLayer.task("reductionsOneBlockFFN", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, + phi3State.tempFFN, + phi3State.wrapX, + phi3Config.dim(), + phi3Config.rmsNormEps(), + phi3State.localSize) + .task("mapContextFFN", + TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, + context, + phi3State.wrapXb, + phi3State.wrapX, + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + phi3State.tempFFN); + + // FFN: combined Up and Gate projection (outputs 2 * hiddenDim, quantized) + unifiedLayer.task("wGateUp", + TransformerComputeKernelsLayered::matrixVectorGeneric, + context, + phi3State.wrapXb, + phi3State.wrapHb, + weights.wUpLayered[layerIndex].getQuants(), + weights.wUpLayered[layerIndex].getScales(), + phi3Config.dim(), + 2 * phi3Config.hiddenDim(), + LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("gateUpSiLU", + TransformerComputeKernelsLayered::splitGateUpAndSiLU, + phi3State.wrapHb, + phi3State.wrapHbG, + phi3State.wrapHbU, + phi3Config.hiddenDim()); + + // FFN: Down projection with residual (quantized) + unifiedLayer.task("wDown", + TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, + context, + phi3State.wrapHbU, + phi3State.wrapX, + weights.wDownLayered[layerIndex].getQuants(), + weights.wDownLayered[layerIndex].getScales(), + phi3Config.hiddenDim(), + phi3Config.dim(), + LOCAL_WORK_GROUP_SIZE_ALLOC) + .persistOnDevice( + phi3State.wrapX + ); + return unifiedLayer; + } + + /** + * Configure data transfers for first and subsequent layers + */ + protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { + // First layer: Transfer initial data to device (one-time transfer) + if (layerIndex == 0) { + // Transfer all attention-related data: query, key, value matrices and their caches + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.positionHolder, state.temp, state.tempFFN); // + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // + context, state.wrapXb, state.wrapXb2, // + state.wrapQ, state.wrapK, state.wrapV, // + state.wrapKeyCache, state.wrapValueCache, // + state.wrapAtt, state.wrapHb, // + phi3State.wrapHbG, phi3State.wrapHbU, phi3State.wrapQkv); // + } else { + // Subsequent layers: Consume data already on device from previous layer + unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2, // + state.wrapQ, state.wrapK, state.wrapV, // + state.wrapKeyCache, state.wrapValueCache, // + state.wrapAtt, state.wrapHb, // + state.positionHolder, // / + phi3State.wrapHbG, phi3State.wrapHbU, phi3State.wrapQkv); + } + return unifiedLayer; + } + +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java new file mode 100644 index 00000000..b2d8d773 --- /dev/null +++ 
b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java @@ -0,0 +1,252 @@ +package org.beehive.gpullama3.tornadovm.layers.type.q8_0; + +import org.beehive.gpullama3.inference.state.Qwen2State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; +import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; +import org.beehive.gpullama3.tornadovm.kernels.Qwen2Kernels; +import org.beehive.gpullama3.tornadovm.kernels.Qwen3Kernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.WorkerGrid1D; +import uk.ac.manchester.tornado.api.WorkerGrid2D; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +import java.util.ArrayList; +import java.util.List; + +/** + * Qwen2Q8_0FFNLayers: Q8_0-quantized FFN layers for Qwen2 with Group Query Attention (GQA) support. + * + * Key Differences from Qwen2FP16FFNLayers: + * - Uses Q8_0-quantized weights (getQuants() and getScales()) + * - Same attention and RoPE kernels as FP16 version + * - 8-bit integer computations with dequantization + * - 2x memory compression vs FP16 + * - Includes bias terms for Q, K, V projections + * + * Works directly with Qwen2State to access and mutate Qwen2-specific state fields. + */ +public class Qwen2Q8_0FFNLayers extends AbstractFFNLayers { + + TaskGraph ffnLayerTaskGraph; + GridScheduler scheduler; + List ffnLayerTaskGraphs; + + // Typed references to Qwen2-specific state and config + private final Qwen2State qwen2State; + private final Qwen2Configuration qwen2Config; + + public Qwen2Q8_0FFNLayers(String taskGraphName, Qwen2State state, Qwen2TornadoWeights weights, Qwen2Configuration config, SchedulerType schedulerType) { + super(taskGraphName, state, weights, config, schedulerType); + this.qwen2State = state; + this.qwen2Config = config; + ffnLayerTaskGraphs = setupFFNLayered(); + } + + @Override + public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { + int h = config.numberOfHeads(); + int ic = config.headSize() / 2; + WorkerGrid ropeWorker = new WorkerGrid2D(h, ic); + ropeWorker.setGlobalWork(h, ic, 1); + ropeWorker.setLocalWork(1, 1, 1); + + + int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configDimRowMajorGlobalWorker = new WorkerGrid1D(configDimRowMajorGlobal); + configDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); + + int configKvDimRowMajorGlobal = config.kvDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configKvDimRowMajorGlobalWorker = new WorkerGrid1D(configKvDimRowMajorGlobal); + configKvDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); + + WorkerGrid qBiasWorker = new WorkerGrid1D(config.dim()); + qBiasWorker.setGlobalWork(config.dim(), 1, 1); + qBiasWorker.setLocalWork(config.dim() / 8, 1, 1); + WorkerGrid kvBiasWorker = new WorkerGrid1D(config.kvDim()); + kvBiasWorker.setGlobalWork(config.kvDim(), 1, 1); + kvBiasWorker.setLocalWork(32, 1, 1); + + int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid 
configHiddenDimRowMajorWorker = new WorkerGrid1D(configHiddenDimRowMajor); + configHiddenDimRowMajorWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1); + + WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 32); + + int optimalLocalSize = Math.min(config.headSize(), 64); // Start with 64 threads per head + if (config.headSize() % optimalLocalSize != 0) { + // Find largest divisor of headSize <= 64 + for (int size = 64; size >= 1; size--) { + if (config.headSize() % size == 0) { + optimalLocalSize = size; + break; + } + } + } + + WorkerGrid parallelAttentionWorker = new WorkerGrid1D(config.numberOfHeads()); + parallelAttentionWorker.setGlobalWork(config.numberOfHeads() * optimalLocalSize, 1, 1); + parallelAttentionWorker.setLocalWork(optimalLocalSize, 1, 1); + + WorkerGrid copyToCachesWorker = new WorkerGrid1D(config.kvDim()); + copyToCachesWorker.setGlobalWork(config.kvDim(), 1, 1); + copyToCachesWorker.setLocalWork(32, 1, 1); // Set local work size to 32 (for copying to caches) + + // Map workers to tasks + for (int i = 0; i < config.numberOfLayers(); i++) { + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", configDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", configKvDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", configKvDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qbias", qBiasWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kbias", kvBiasWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vbias", kvBiasWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker); + } + return tornadoForwardScheduler; + } + + @Override + public GridScheduler getGridScheduler() { + return scheduler; + } + + @Override + public TaskGraph getTaskGraph() { + return ffnLayerTaskGraph; + } + + @Override + public ImmutableTaskGraph getImmutableTaskGraph() { + return null; + } + + public List getFfnLayerTaskGraphs() { + return ffnLayerTaskGraphs; + } + + /** + * Setup all FFN layers for all transformer layers + */ + List setupFFNLayered() { + List ffnGraphs = new ArrayList<>(); + qwen2State.temp.init(0.0f); + qwen2State.tempFFN.init(0.0f); + + for (int layerIndex = 0; layerIndex < qwen2Config.numberOfLayers(); layerIndex++) { + TaskGraph ffnLayer = setupSingleQwen2Q8_0FFNLayer((Qwen2TornadoWeights) weights, layerIndex); + if (layerIndex == qwen2Config.numberOfLayers() - 1) { + setupLastID(ffnLayer.getTaskGraphName()); + } + ffnGraphs.add(ffnLayer.snapshot()); + } + return ffnGraphs; + } + + /** + * Setup a single transformer 
layer for Qwen2 with Q8_0 quantization and GQA + */ + TaskGraph setupSingleQwen2Q8_0FFNLayer(Qwen2TornadoWeights weights, int layerIndex) { + TaskGraph unifiedLayer = new TaskGraph("layer_" + layerIndex); + unifiedLayer.consumeFromDevice(state.wrapX); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + //Copy-in weights per layer for batched-layered layout + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + weights.wqLayered[layerIndex].getScales(), + weights.wqLayered[layerIndex].getQuants(), + weights.wkLayered[layerIndex].getScales(), + weights.wkLayered[layerIndex].getQuants(), + weights.wvLayered[layerIndex].getScales(), + weights.wvLayered[layerIndex].getQuants(), + weights.woLayered[layerIndex].getScales(), + weights.woLayered[layerIndex].getQuants(), + weights.q_biasLayered[layerIndex].asFloatArray(), + weights.k_biasLayered[layerIndex].asFloatArray(), + weights.v_biasLayered[layerIndex].asFloatArray(), + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.w1Layered[layerIndex].getScales(), + weights.w1Layered[layerIndex].getQuants(), + weights.w2Layered[layerIndex].getScales(), + weights.w2Layered[layerIndex].getQuants(), + weights.w3Layered[layerIndex].getScales(), + weights.w3Layered[layerIndex].getQuants() + ); + unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); + + unifiedLayer.task("reductionsOneBlock" , TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, + state.wrapX, config.dim(), config.rmsNormEps(), state.localSize) + .task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, + state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp) + .task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, + state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex].getQuants(), weights.wqLayered[layerIndex].getScales(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, + state.wrapXb, state.wrapK, weights.wkLayered[layerIndex].getQuants(), weights.wkLayered[layerIndex].getScales(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("vmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, + state.wrapXb, state.wrapV, weights.wvLayered[layerIndex].getQuants(), weights.wvLayered[layerIndex].getScales(), config.dim(), config.kvDim(), LOCAL_WORK_GROUP_SIZE_ALLOC) + .task("qbias", TransformerComputeKernelsLayered::addInPlace, state.wrapQ, weights.q_biasLayered[layerIndex].asFloatArray(), config.dim()) + .task("kbias", TransformerComputeKernelsLayered::addInPlace, state.wrapK, weights.k_biasLayered[layerIndex].asFloatArray(), config.kvDim()) + .task("vbias", TransformerComputeKernelsLayered::addInPlace, state.wrapV, weights.v_biasLayered[layerIndex].asFloatArray(), config.kvDim()) + .task("rope", Qwen3Kernels::ropeRotation,context, state.positionHolder, state.wrapQ, state.wrapK, config.numberOfKeyValueHeads(), + config.headSize()) + .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache, + state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(), layerIndex, config.contextLength()) + .task("parallel-attention", Qwen2Kernels::processHeadsFlashAttention, context, + state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb, + config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(), + 
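+                        // The KV cache is only kvDim() wide per position; kvMul() maps each
+                        // query head onto its shared KV head (GQA).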
+                        state.positionHolder, layerIndex, config.contextLength())
+                .task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
+                        state.wrapXb, state.wrapX, weights.woLayered[layerIndex].getQuants(), weights.woLayered[layerIndex].getScales(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
+                .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN,
+                        state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
+                .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb,
+                        state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN)
+                .task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context,
+                        state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex].getQuants(), weights.w1Layered[layerIndex].getScales(), weights.w3Layered[layerIndex].getQuants(), weights.w3Layered[layerIndex].getScales(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
+                .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
+                        state.wrapHb, state.wrapX, weights.w2Layered[layerIndex].getQuants(), weights.w2Layered[layerIndex].getScales(), config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
+                .persistOnDevice(
+                        state.wrapX
+                );
+        return unifiedLayer;
+    }
+
+    /**
+     * Configure data transfers for the first and subsequent layers
+     */
+    protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) {
+        if (layerIndex == 0) {
+            // First layer: transfer the temporary buffers and position holder on every execution
+            unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION,
+                    qwen2State.positionHolder, qwen2State.temp, qwen2State.tempFFN);
+            // First execution: allocate workspace buffers
+            unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
+                    context, qwen2State.wrapXb, qwen2State.wrapXb2,
+                    qwen2State.wrapQ, qwen2State.wrapK, qwen2State.wrapV,
+                    qwen2State.wrapKeyCache, qwen2State.wrapValueCache,
+                    qwen2State.wrapAtt, qwen2State.wrapHb);
+        } else {
+            // Subsequent layers: consume data already resident from the previous layer
+            unifiedLayer.consumeFromDevice(context, qwen2State.wrapXb, qwen2State.wrapXb2,
+                    qwen2State.wrapQ, qwen2State.wrapK, qwen2State.wrapV,
+                    qwen2State.wrapKeyCache, qwen2State.wrapValueCache,
+                    qwen2State.wrapAtt, qwen2State.wrapHb, qwen2State.positionHolder);
+        }
+        return unifiedLayer;
+    }
+
+}
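(Editor's note: both FFN-layer classes in this diff follow the same TornadoVM
data-movement pattern, summarized here as a minimal sketch. The identifiers
weightsArray, positionHolder, input, output, and MyKernels::step are
placeholders, not names from this PR:

    TaskGraph tg = new TaskGraph("layer_0");
    // Weights: uploaded once, then cached on the device across executions.
    tg.transferToDevice(DataTransferMode.FIRST_EXECUTION, weightsArray);
    // Small mutable state (e.g. the token position): re-uploaded every execution.
    tg.transferToDevice(DataTransferMode.EVERY_EXECUTION, positionHolder);
    tg.task("step", MyKernels::step, input, output);
    // Keep the activation resident; the next layer's graph takes it over via
    // consumeFromDevice(...) instead of a host round-trip.
    tg.persistOnDevice(output);

This is why only layer_0 issues transferToDevice calls for the workspace
buffers, while layers 1..N-1 call consumeFromDevice on the same objects.)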
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java
new file mode 100644
index 00000000..ba090bf5
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java
@@ -0,0 +1,319 @@
+package org.beehive.gpullama3.tornadovm.layers.type.q8_0;
+
+import org.beehive.gpullama3.inference.state.Qwen3State;
+import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights;
+import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
+import org.beehive.gpullama3.tornadovm.kernels.Qwen3Kernels;
+import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
+import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+import uk.ac.manchester.tornado.api.TaskGraph;
+import uk.ac.manchester.tornado.api.WorkerGrid;
+import uk.ac.manchester.tornado.api.enums.DataTransferMode;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Qwen3Q8_0FFNLayers: Q8_0-quantized FFN layers for Qwen3 with Group Query Attention (GQA) support.
+ *
+ * Key differences from Qwen3FP16FFNLayers:
+ * - Uses Q8_0-quantized weights (getQuants() and getScales())
+ * - Same Qwen3Kernels for RMSNorm and RoPE
+ * - 8-bit integer storage with on-the-fly dequantization
+ * - 2x memory compression vs FP16
+ *
+ * Works directly with Qwen3State to access and mutate Qwen3-specific state fields
+ * like tempQcur and tempKcur.
+ */
+public class Qwen3Q8_0FFNLayers extends AbstractFFNLayers {
+
+    String lastTaskGraphID;
+    TaskGraph ffnLayerTaskGraph;
+    GridScheduler scheduler;
+    List<ImmutableTaskGraph> ffnLayerTaskGraphs;
+
+    // Typed references to Qwen3-specific state and config
+    private final Qwen3State qwen3State;
+    private final Qwen3Configuration qwen3Config;
+
+    // Qwen3-specific GQA parameters
+    private final int nHeadKv;
+    private final int nEmbdHeadK;
+    private final int nEmbdHeadV;
+    private final int nEmbdVGqa;
+    private final int nEmbdHead;
+    private final int nEmbdGqa;
+    private final int gqa;
+
+    public Qwen3Q8_0FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWeights weights, Qwen3Configuration config, SchedulerType schedulerType) {
+        super(taskGraphName, state, weights, config, schedulerType);
+        this.qwen3State = state;
+        this.qwen3Config = config;
+        this.nHeadKv = config.numberOfKeyValueHeads();
+        this.nEmbdHeadK = config.numberOfHeadsKey();
+        this.nEmbdHeadV = config.numberOfHeadsValue();
+        this.nEmbdVGqa = nEmbdHeadV * nHeadKv;
+        this.nEmbdHead = nEmbdHeadV;
+        this.nEmbdGqa = nEmbdVGqa;
+        this.gqa = config.numberOfHeads() / config.numberOfKeyValueHeads();
+        ffnLayerTaskGraphs = setupFFNLayered();
+    }
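+    /*
+     * Editor's note: a worked example of the GQA bookkeeping above, using
+     * hypothetical Qwen3-like dimensions (not values read from this PR):
+     *
+     *     numberOfHeads = 32, nHeadKv = 8, nEmbdHead = 128
+     *     gqa      = 32 / 8  = 4     // 4 query heads share each KV head
+     *     nEmbdGqa = 128 * 8 = 1024  // width of one K (or V) row in the cache
+     *
+     * Query head h therefore attends through KV head h / gqa, which is what
+     * lets wrapK, wrapV, and both caches be gqa times narrower than wrapQ.
+     */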
+    @Override
+    public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
+        WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), state.localSize);
+
+        int matmulQGlobal = nEmbdHeadK * config.numberOfHeads() * LOCAL_WORK_GROUP_SIZE_ALLOC;
+        WorkerGrid matmulQRowMajorWorker = WorkerGridFactory.genericWorker(matmulQGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+        int matmulKVGlobal = nEmbdGqa * LOCAL_WORK_GROUP_SIZE_ALLOC;
+        WorkerGrid matmulKVRowMajorWorker = WorkerGridFactory.genericWorker(matmulKVGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+        WorkerGrid qCurWorker = WorkerGridFactory.genericWorker(config.numberOfHeads() * nEmbdHead, nEmbdHead);
+        WorkerGrid kCurWorker = WorkerGridFactory.genericWorker(config.numberOfKeyValueHeads() * nEmbdHead, nEmbdHead);
+
+        int h = config.numberOfHeads();
+        WorkerGrid ropeWorker = WorkerGridFactory.createRoPEWorker(h, nEmbdHead);
+        WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(nEmbdGqa, 128);
+        WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), nEmbdHead);
+
+        int matmul1Global = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
+        WorkerGrid matmul1Worker = WorkerGridFactory.genericWorker(matmul1Global, LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+        int fusedFFNW1W3Global = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
+        WorkerGrid fusedFFNW1W3Worker = WorkerGridFactory.genericWorker(fusedFFNW1W3Global, LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+        int projectionTwoGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
+        WorkerGrid projectionTwoWorker = WorkerGridFactory.genericWorker(projectionTwoGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+        for (int i = 0; i < config.numberOfLayers(); i++) {
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qmatmul", matmulQRowMajorWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".kmatmul", matmulKVRowMajorWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".vmatmul", matmulKVRowMajorWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rmsnormReduction_Qcur", qCurWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rmsnormMapIndexInPlace_Qcur", qCurWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rmsnormReduction_Kcur", kCurWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rmsnormMapIndexInPlace_Kcur", kCurWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ropeRotation", ropeWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", matmul1Worker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", fusedFFNW1W3Worker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", projectionTwoWorker);
+        }
+        return tornadoForwardScheduler;
+    }
+
+    @Override
+    public GridScheduler getGridScheduler() {
+        return scheduler;
+    }
+
+    @Override
+    public TaskGraph getTaskGraph() {
+        return ffnLayerTaskGraph;
+    }
+
+    @Override
+    public ImmutableTaskGraph getImmutableTaskGraph() {
+        return null;
+    }
+
+    public List<ImmutableTaskGraph> getFfnLayerTaskGraphs() {
+        return ffnLayerTaskGraphs;
+    }
+
+    /**
+     * Setup all FFN layers for all transformer layers
+     */
+    List<ImmutableTaskGraph> setupFFNLayered() {
+        List<ImmutableTaskGraph> ffnGraphs = new ArrayList<>();
+        qwen3State.temp.init(0.0f);
+        qwen3State.tempFFN.init(0.0f);
+        qwen3State.tempQcur.init(0.0f);
+        qwen3State.tempKcur.init(0.0f);
+
+        for (int layerIndex = 0; layerIndex < qwen3Config.numberOfLayers(); layerIndex++) {
+            TaskGraph ffnLayer = setupSingleQwen3FFNLayer((Qwen3TornadoWeights) weights, layerIndex);
+            if (layerIndex == qwen3Config.numberOfLayers() - 1) {
+                setupLastID(ffnLayer.getTaskGraphName());
+            }
+            ffnGraphs.add(ffnLayer.snapshot());
+        }
+        return ffnGraphs;
+    }
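+    /*
+     * Editor's note: why each Q8_0 tensor below is shipped as a (quants, scales)
+     * pair. In GGUF's Q8_0 format, weights are stored as blocks of 32 signed
+     * 8-bit values plus one float scale per block, so a kernel recovers element
+     * i as roughly (illustrative sketch, not a method in this PR):
+     *
+     *     float dequantQ8_0(ByteArray quants, FloatArray scales, int i) {
+     *         return quants.get(i) * scales.get(i / 32);  // 32 = Q8_0 block size
+     *     }
+     *
+     * The concrete array types here are stand-ins for whatever getQuants() and
+     * getScales() return; only the quant * per-block-scale relationship matters.
+     */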
+    /**
+     * Setup a single transformer layer for Qwen3 with GQA (Q8_0 quantized)
+     */
+    TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) {
+
+        var unifiedLayerName = "layer_" + layerIndex;
+        TaskGraph unifiedLayer = new TaskGraph(unifiedLayerName);
+        unifiedLayer.consumeFromDevice(qwen3State.wrapX);
+        // Transfer Q8_0 weights for this layer (quants and scales)
+        unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
+                weights.rms_att_weightLayered[layerIndex].asFloatArray(), //
+                weights.wqLayered[layerIndex].getQuants(), //
+                weights.wqLayered[layerIndex].getScales(), //
+                weights.wkLayered[layerIndex].getQuants(), //
+                weights.wkLayered[layerIndex].getScales(), //
+                weights.wvLayered[layerIndex].getQuants(), //
+                weights.wvLayered[layerIndex].getScales(), //
+                weights.woLayered[layerIndex].getQuants(), //
+                weights.woLayered[layerIndex].getScales(), //
+                weights.rms_att_KNormLayered[layerIndex].asFloatArray(), //
+                weights.rms_att_QNormLayered[layerIndex].asFloatArray(), //
+                weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), //
+                weights.w1Layered[layerIndex].getQuants(), //
+                weights.w1Layered[layerIndex].getScales(), //
+                weights.w2Layered[layerIndex].getQuants(), //
+                weights.w2Layered[layerIndex].getScales(), //
+                weights.w3Layered[layerIndex].getQuants(), //
+                weights.w3Layered[layerIndex].getScales()); //
+
+        // Configure layer data transfers (EVERY_EXECUTION and device persistence)
+        unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
+
+        // RMS norm for attention input
+        unifiedLayer.task("reductionsOneBlock",
+                        TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
+                        context, qwen3State.temp, qwen3State.wrapX, config.dim(), config.rmsNormEps(), qwen3State.localSize)
+                .task("mapContext",
+                        TransformerComputeKernelsLayered::reductionOneBlock2WithLayer,
+                        context, qwen3State.wrapXb, qwen3State.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), qwen3State.temp);
+
+        // QKV projections with Qwen3 GQA dimensions;
+        // Q8_0 weights pass both quants and scales
+        int qDim0 = nEmbdHeadK * config.numberOfHeads(); // Query dimension
+        int kvDim0 = nEmbdGqa; // KV dimension (smaller due to GQA)
+        int qkvDim1 = config.dim(); // Input dimension
+
+        unifiedLayer.task("qmatmul",
+                        TransformerComputeKernelsLayered::matrixVectorGeneric,
+                        context, qwen3State.wrapXb, qwen3State.wrapQ,
+                        weights.wqLayered[layerIndex].getQuants(), weights.wqLayered[layerIndex].getScales(),
+                        qkvDim1, qDim0, LOCAL_WORK_GROUP_SIZE_ALLOC)
+                .task("kmatmul",
+                        TransformerComputeKernelsLayered::matrixVectorGeneric,
+                        context, qwen3State.wrapXb, qwen3State.wrapK,
+                        weights.wkLayered[layerIndex].getQuants(), weights.wkLayered[layerIndex].getScales(),
+                        qkvDim1, kvDim0, LOCAL_WORK_GROUP_SIZE_ALLOC)
+                .task("vmatmul",
+                        TransformerComputeKernelsLayered::matrixVectorGeneric,
+                        context, qwen3State.wrapXb, qwen3State.wrapV,
+                        weights.wvLayered[layerIndex].getQuants(), weights.wvLayered[layerIndex].getScales(),
+                        qkvDim1, kvDim0, LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+        // Qcur: RMS norm with parallel offset for Query
+        unifiedLayer.task("rmsnormReduction_Qcur",
+                        Qwen3Kernels::rmsnormWithParallelOffset,
+                        context, qwen3State.tempQcur, qwen3State.wrapQ, qwen3State.localSize, nEmbdHead, config.rmsNormEps())
+                .task("rmsnormMapIndexInPlace_Qcur",
+                        Qwen3Kernels::rmsnormMapIndexInPlaceWithParallelOffset,
+                        context, qwen3State.wrapQ, weights.rms_att_QNormLayered[layerIndex].asFloatArray(), nEmbdHead, qwen3State.tempQcur);
+
+        // Kcur: RMS norm with parallel offset for Key
+        unifiedLayer.task("rmsnormReduction_Kcur",
+                        Qwen3Kernels::rmsnormWithParallelOffset,
+                        context, qwen3State.tempKcur, qwen3State.wrapK, qwen3State.localSize, nEmbdHead, config.rmsNormEps())
+                .task("rmsnormMapIndexInPlace_Kcur",
+                        Qwen3Kernels::rmsnormMapIndexInPlaceWithParallelOffset,
+                        context, qwen3State.wrapK, weights.rms_att_KNormLayered[layerIndex].asFloatArray(), nEmbdHead, qwen3State.tempKcur);
+
+        // RoPE rotation (Qwen3 variant)
+        unifiedLayer.task("ropeRotation",
+                Qwen3Kernels::ropeRotation,
+                context, qwen3State.positionHolder, qwen3State.wrapQ, qwen3State.wrapK,
+                config.numberOfKeyValueHeads(), nEmbdHead);
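+        /*
+         * Editor's note: a sketch of the rotation a RoPE kernel applies per head
+         * (illustrative; the actual kernel is Qwen3Kernels::ropeRotation). For the
+         * even/odd pair (x0, x1) at pair index i in a head of size d at position p:
+         *
+         *     double freq  = Math.pow(base, -2.0 * i / d); // base (e.g. 10000) comes from the model config
+         *     double theta = p * freq;
+         *     r0 = x0 * Math.cos(theta) - x1 * Math.sin(theta);
+         *     r1 = x0 * Math.sin(theta) + x1 * Math.cos(theta);
+         *
+         * Q is rotated for all numberOfHeads heads, K only for the
+         * numberOfKeyValueHeads KV heads, which is why the task takes both.
+         */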
+        // Copy to KV cache
+        unifiedLayer.task("copyToCaches",
+                TransformerComputeKernelsLayered::copyToCache,
+                qwen3State.wrapKeyCache, qwen3State.wrapK, qwen3State.wrapValueCache, qwen3State.wrapV,
+                qwen3State.positionHolder, nEmbdGqa, layerIndex, config.contextLength());
+
+        // Parallel attention (with GQA support)
+        unifiedLayer.task("parallel-attention",
+                TransformerComputeKernelsLayered::processHeadsFlashAttentionOpt,
+                context, qwen3State.wrapQ, qwen3State.wrapKeyCache, qwen3State.wrapValueCache, qwen3State.wrapXb,
+                config.numberOfHeads(), nEmbdHead, nEmbdGqa, gqa, qwen3State.positionHolder, layerIndex, config.contextLength());
+
+        // Output projection (Q8_0 weights)
+        unifiedLayer.task("matmul1",
+                TransformerComputeKernelsLayered::matrixVectorGenericWithResidual,
+                context, qwen3State.wrapXb, qwen3State.wrapX,
+                weights.woLayered[layerIndex].getQuants(), weights.woLayered[layerIndex].getScales(),
+                qDim0, config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+        // ========== FEED-FORWARD BLOCK ==========
+
+        // RMS norm for FFN input
+        unifiedLayer.task("reductionsOneBlockFFN",
+                        TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
+                        context, qwen3State.tempFFN, qwen3State.wrapX, config.dim(), config.rmsNormEps(), qwen3State.localSize)
+                .task("mapContextFFN",
+                        TransformerComputeKernelsLayered::reductionOneBlock2WithLayer,
+                        context, qwen3State.wrapXb, qwen3State.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), qwen3State.tempFFN);
+
+        // Fused FFN: w1(x) ⊗ w3(x) with SiLU activation (Q8_0 weights)
+        unifiedLayer.task("fused_ffn_w1_w3",
+                        TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation,
+                        context, qwen3State.wrapXb, qwen3State.wrapHb,
+                        weights.w1Layered[layerIndex].getQuants(), weights.w1Layered[layerIndex].getScales(),
+                        weights.w3Layered[layerIndex].getQuants(), weights.w3Layered[layerIndex].getScales(),
+                        config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
+                .task("projectionTwo",
+                        TransformerComputeKernelsLayered::matrixVectorGenericWithResidual,
+                        context, qwen3State.wrapHb, qwen3State.wrapX,
+                        weights.w2Layered[layerIndex].getQuants(), weights.w2Layered[layerIndex].getScales(),
+                        config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
+                .persistOnDevice(state.wrapX);
+
+        return unifiedLayer;
+    }
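+    /*
+     * Editor's note: what the "fused_ffn_w1_w3" task above computes, written out
+     * per hidden unit j (a sketch of the SwiGLU formulation, not this PR's kernel):
+     *
+     *     float gate = dot(w1[j], xb);                          // gate projection
+     *     float up   = dot(w3[j], xb);                          // up projection
+     *     float silu = gate / (1.0f + (float) Math.exp(-gate)); // SiLU(x) = x * sigmoid(x)
+     *     hb[j] = silu * up;                                    // GLU: gate (*) up
+     *
+     * "projectionTwo" then maps hb (hiddenDim wide) back to dim through w2 and
+     * adds the residual, hence the WithResidual suffix on that kernel.
+     */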
+    /**
+     * Configure data transfers for the first and subsequent layers
+     */
+    protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) {
+        if (layerIndex == 0) {
+            // First layer: transfer the temporary buffers and position holder on every execution
+            unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION,
+                    qwen3State.positionHolder, qwen3State.temp, qwen3State.tempFFN);
+            unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION,
+                    qwen3State.tempQcur, qwen3State.tempKcur);
+
+            // First execution: allocate workspace buffers
+            unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, //
+                    context, qwen3State.wrapXb, qwen3State.wrapXb2, //
+                    qwen3State.wrapQ, qwen3State.wrapK, qwen3State.wrapV, //
+                    qwen3State.wrapKeyCache, qwen3State.wrapValueCache, //
+                    qwen3State.wrapAtt, qwen3State.wrapHb); //
+        } else {
+            // Subsequent layers: consume data already resident from the previous layer
+            unifiedLayer.consumeFromDevice(context, qwen3State.wrapXb, qwen3State.wrapXb2, //
+                    qwen3State.wrapQ, qwen3State.wrapK, qwen3State.wrapV, //
+                    qwen3State.wrapKeyCache, qwen3State.wrapValueCache, //
+                    qwen3State.wrapAtt, qwen3State.wrapHb, qwen3State.positionHolder); //
+            unifiedLayer.consumeFromDevice(qwen3State.tempQcur, qwen3State.tempKcur); //
+        }
+        return unifiedLayer;
+    }
+
+}
\ No newline at end of file
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/FloatArrayUtils.java b/src/main/java/org/beehive/gpullama3/tornadovm/utils/FloatArrayUtils.java
similarity index 99%
rename from src/main/java/org/beehive/gpullama3/tornadovm/FloatArrayUtils.java
rename to src/main/java/org/beehive/gpullama3/tornadovm/utils/FloatArrayUtils.java
index 2e395339..23ef13cc 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/FloatArrayUtils.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/utils/FloatArrayUtils.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.tornadovm;
+package org.beehive.gpullama3.tornadovm.utils;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
 import uk.ac.manchester.tornado.api.math.TornadoMath;