From 8c43642a8cf3eaee10fae7ae6ae754bec4502712 Mon Sep 17 00:00:00 2001 From: hkxhrwang Date: Sun, 12 Jan 2025 22:12:26 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=E4=BD=BF=E7=94=A8=20Colab=20=E5=88=9B?= =?UTF-8?q?=E5=BB=BA=E8=80=8C=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Deconstructing-QWen2-from-Ground-Up.ipynb | 1497 +++++++++++++++++++++ 1 file changed, 1497 insertions(+) create mode 100644 Deconstructing-QWen2-from-Ground-Up.ipynb diff --git a/Deconstructing-QWen2-from-Ground-Up.ipynb b/Deconstructing-QWen2-from-Ground-Up.ipynb new file mode 100644 index 00000000..d4a037fc --- /dev/null +++ b/Deconstructing-QWen2-from-Ground-Up.ipynb @@ -0,0 +1,1497 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e2c7e6cd-6087-42eb-add7-0143210de63e", + "metadata": { + "id": "e2c7e6cd-6087-42eb-add7-0143210de63e" + }, + "source": [ + "# Deconstructing QWen2 from the Ground Up\n", + "In this project, I will demonstrate how to deconstruct QWen2 from scratch. Specifically, I will explore how to complete a Chinese proverb which is generating a \"退\" from the input input_text=\"学习如逆水行舟,不进则\". I hope this project will help everyone gain a better understanding of the structure of QWen2, and also want to take this opportunity to promote China's LLM.\n", + "\n", + "Here is the offical link to download the weights: **https://www.modelscope.cn/models/qwen/Qwen2-7B/files**\n", + "\n", + "
\n", + " \n", + "
" + ] + }, + { + "cell_type": "code", + "source": [ + "!git lfs install\n", + "!git clone https://huggingface.co/Qwen/Qwen2.5-3B" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "MMTYx-EKr_Fv", + "outputId": "9d9b51c3-03cd-40ec-ebf6-5251dcdefe47" + }, + "id": "MMTYx-EKr_Fv", + "execution_count": 3, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Git LFS initialized.\n", + "Cloning into 'Qwen2.5-3B'...\n", + "remote: Enumerating objects: 50, done.\u001b[K\n", + "remote: Counting objects: 100% (47/47), done.\u001b[K\n", + "remote: Compressing objects: 100% (47/47), done.\u001b[K\n", + "remote: Total 50 (delta 21), reused 0 (delta 0), pack-reused 3 (from 1)\u001b[K\n", + "Unpacking objects: 100% (50/50), 3.61 MiB | 5.17 MiB/s, done.\n", + "Filtering content: 100% (2/2), 5.74 GiB | 38.75 MiB/s, done.\n" + ] + } + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "00d46041-7dbd-4d7c-bb14-2d7bbd7d9a48", + "metadata": { + "id": "00d46041-7dbd-4d7c-bb14-2d7bbd7d9a48" + }, + "outputs": [], + "source": [ + "import torch\n", + "import json\n", + "import matplotlib.pyplot as plt\n", + "import math\n", + "from torch import nn\n", + "import torch\n", + "from transformers import AutoModelForCausalLM, AutoTokenizer" + ] + }, + { + "cell_type": "markdown", + "id": "c3846ceb-a2d3-4104-9b8e-c1c56f6fd115", + "metadata": { + "id": "c3846ceb-a2d3-4104-9b8e-c1c56f6fd115" + }, + "source": [ + "# Tokenizer\n", + "Here, I'm not going to show the principle and implementation of LLM's tokenizer. Andrej Karpathy has provided a one-to-one implementation of GPT4Tokenizer. His code is really easy to understand!!!\n", + "\n", + "
\n", + " \n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "e3116bf7-4777-4a76-be81-7ff8c52b1cfe", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 49, + "referenced_widgets": [ + "26c723f59a96410489b443cbef0cb4a2", + "5aeeeb91317a4d91b5a4f04a6a92fe99", + "29a762cdc7194ed5a9678abc3f55f496", + "b50c06cfcbb74502a75ddec0e87e7157", + "474d5a23fe984a2fa80e13f0cbd187bb", + "d656cf9b3c414230828e38146096942c", + "e4f07a621e6b4eb29380afcda7af8ba1", + "bdc1353aea5b4a5c86d279efeb483f7f", + "9ae197159ee44b629c2e0b440dae62cb", + "68e131ba08a744f6b085d75bb59c674e", + "f451c98a8f1a4ee99380d30ed185855e" + ] + }, + "id": "e3116bf7-4777-4a76-be81-7ff8c52b1cfe", + "outputId": "71e65e26-9c4f-44e5-8b5b-b19ba0c64ee9" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/2 [00:00\n", + " \n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "40fda03f-4ff6-4729-8f14-cf4ea4af5bea", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "40fda03f-4ff6-4729-8f14-cf4ea4af5bea", + "outputId": "76d5778b-6bdb-409e-f122-ff1d8f7af8ca" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[\n", + " \"model.embed_tokens.weight\",\n", + " \"model.layers.0.self_attn.q_proj.weight\",\n", + " \"model.layers.0.self_attn.q_proj.bias\",\n", + " \"model.layers.0.self_attn.k_proj.weight\",\n", + " \"model.layers.0.self_attn.k_proj.bias\",\n", + " \"model.layers.0.self_attn.v_proj.weight\",\n", + " \"model.layers.0.self_attn.v_proj.bias\",\n", + " \"model.layers.0.self_attn.o_proj.weight\",\n", + " \"model.layers.0.mlp.gate_proj.weight\",\n", + " \"model.layers.0.mlp.up_proj.weight\",\n", + " \"model.layers.0.mlp.down_proj.weight\",\n", + " \"model.layers.0.input_layernorm.weight\",\n", + " \"model.layers.0.post_attention_layernorm.weight\",\n", + " \"model.layers.1.self_attn.q_proj.weight\",\n", + " 
\"model.layers.1.self_attn.q_proj.bias\",\n", + " \"model.layers.1.self_attn.k_proj.weight\",\n", + " \"model.layers.1.self_attn.k_proj.bias\",\n", + " \"model.layers.1.self_attn.v_proj.weight\",\n", + " \"model.layers.1.self_attn.v_proj.bias\",\n", + " \"model.layers.1.self_attn.o_proj.weight\"\n", + "]\n" + ] + } + ], + "source": [ + "model = model.state_dict()\n", + "print(json.dumps(list(model.keys())[:20], indent=4))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "22de6dab-0872-4adc-bd35-df5d64d21a66", + "metadata": { + "scrolled": true, + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "22de6dab-0872-4adc-bd35-df5d64d21a66", + "outputId": "923b5737-2803-410c-b3ef-a781a611e7c6" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "{'architectures': ['Qwen2ForCausalLM'],\n", + " 'attention_dropout': 0.0,\n", + " 'bos_token_id': 151643,\n", + " 'eos_token_id': 151643,\n", + " 'hidden_act': 'silu',\n", + " 'hidden_size': 2048,\n", + " 'initializer_range': 0.02,\n", + " 'intermediate_size': 11008,\n", + " 'max_position_embeddings': 32768,\n", + " 'max_window_layers': 36,\n", + " 'model_type': 'qwen2',\n", + " 'num_attention_heads': 16,\n", + " 'num_hidden_layers': 36,\n", + " 'num_key_value_heads': 2,\n", + " 'rms_norm_eps': 1e-06,\n", + " 'rope_theta': 1000000.0,\n", + " 'sliding_window': 32768,\n", + " 'tie_word_embeddings': True,\n", + " 'torch_dtype': 'bfloat16',\n", + " 'transformers_version': '4.40.1',\n", + " 'use_cache': True,\n", + " 'use_mrope': False,\n", + " 'use_sliding_window': False,\n", + " 'vocab_size': 151936}" + ] + }, + "metadata": {}, + "execution_count": 7 + } + ], + "source": [ + "with open(\"./Qwen2.5-3B/config.json\", \"r\") as f:\n", + " config = json.load(f)\n", + "config" + ] + }, + { + "cell_type": "markdown", + "id": "7587d42a-73a9-4a64-88b0-52533baacdcf", + "metadata": { + "id": "7587d42a-73a9-4a64-88b0-52533baacdcf" + }, + "source": [ + "## We will use these 
configs to assemble the QWen2\n", + "1. 36 transformer layers\n", + "2. 16 attention heads\n", + "3. 2 key-value heads, and so on." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "4eeed7fa-af82-469e-b1ad-ba65db385055", + "metadata": { + "id": "4eeed7fa-af82-469e-b1ad-ba65db385055" + }, + "outputs": [], + "source": [ + "dim = config[\"hidden_size\"]\n", + "n_layers = config[\"num_hidden_layers\"]\n", + "n_heads = config[\"num_attention_heads\"]\n", + "n_kv_heads = config[\"num_key_value_heads\"]\n", + "vocab_size = config[\"vocab_size\"]\n", + "norm_eps = config[\"rms_norm_eps\"]\n", + "rope_theta = torch.tensor(config[\"rope_theta\"])" + ] + }, + { + "cell_type": "markdown", + "id": "ea783ba8-5770-4f03-ab8c-9f4dccd4f1a2", + "metadata": { + "id": "ea783ba8-5770-4f03-ab8c-9f4dccd4f1a2" + }, + "source": [ + "## Convert text to tokens\n", + "I'm going to use QWen2's built-in tokenizer for this demonstration.\n", + "\n", + "You may be confused about why \"学习\" and \",不\" each fit in a single token (consider the principle of BPE). Later, some other Chinese characters may be represented by two or more tokens, like \"炊\".\n", + "
\n", + " \n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "13da007c-c5c9-442e-a516-3d62ee2e78a5", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "13da007c-c5c9-442e-a516-3d62ee2e78a5", + "outputId": "476c4367-12f8-4f8c-b10a-6815a7a73dd3" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[100134, 29524, 100531, 52510, 22243, 102748, 3837, 16530, 41299, 46448]\n", + "['学习', '如', '逆', '水', '行', '舟', ',', '不', '进', '则']\n" + ] + } + ], + "source": [ + "prompt = \"学习如逆水行舟,不进则\"\n", + "tokens = tokenizer.encode(prompt)\n", + "print(tokens)\n", + "tokens = torch.tensor(tokens)\n", + "prompt_split_as_tokens = [tokenizer.decode([token.item()]) for token in tokens]\n", + "print(prompt_split_as_tokens)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "471ef96d-8fc3-4f65-9c90-9fd9abd6718a", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "471ef96d-8fc3-4f65-9c90-9fd9abd6718a", + "outputId": "75adbaa3-49e7-4912-8ef7-e12cfa6bc717" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "torch.Size([10, 2048])\n" + ] + } + ], + "source": [ + "embedding_layer = torch.nn.Embedding.from_pretrained(model['model.embed_tokens.weight'])\n", + "token_embeddings_unnormalized = embedding_layer(tokens.to(device))\n", + "print(token_embeddings_unnormalized.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "8054bce7-402c-4caf-a199-87d237a34b23", + "metadata": { + "id": "8054bce7-402c-4caf-a199-87d237a34b23" + }, + "source": [ + "## Normalize the embedding using rms normalization\n", + "RMS normalization (Root Mean Square normalization) is used in the embedding layers of Large Language Models (LLMs) for several reasons:\n", + "1. Stabilizing Training\n", + "2. Improving Convergence\n", + "3. Better Generalization\n", + "4. 
Handling Variability in Embedding Magnitudes\n", + "\n", + "It's worth noting that we need to set a norm_eps to avoid division by zero in the formula.\n", + "\n", + "
\n", + " \n", + "
\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "64588068-1f7e-46df-97eb-3c837e501c7e", + "metadata": { + "id": "64588068-1f7e-46df-97eb-3c837e501c7e" + }, + "outputs": [], + "source": [ + "def rms_norm(tensor, norm_weights):\n", + " return (tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + norm_eps)) * norm_weights" + ] + }, + { + "cell_type": "markdown", + "id": "01a4c090-46d4-4ed3-a965-fc1ce24c1509", + "metadata": { + "id": "01a4c090-46d4-4ed3-a965-fc1ce24c1509" + }, + "source": [ + "# Build the first transformer layer\n", + "### Normalization\n", + "You can see, after through layer0 from the dict extract from the model.\n", + "\n", + "The output tensor is still shape in [10*3584] but normalized.\n", + "\n", + "
\n", + " \n", + "
\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "67b76a3f-67dd-4366-a379-c6e5a0405799", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "67b76a3f-67dd-4366-a379-c6e5a0405799", + "outputId": "9843b5bf-70cf-4fe9-ff09-daae850428ac" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "torch.Size([10, 2048])\n" + ] + } + ], + "source": [ + "token_embeddings = rms_norm(token_embeddings_unnormalized, model[\"model.layers.0.input_layernorm.weight\"])\n", + "print(token_embeddings.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "1efc6246-ce93-4ae8-879e-390e04122af0", + "metadata": { + "id": "1efc6246-ce93-4ae8-879e-390e04122af0" + }, + "source": [ + "## Assemble attention from scratch\n", + "Load the attention heads of the first layer of transformer.\n", + "\n", + "
\n", + " \n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "da9dbcd2-b516-4e55-b998-712e2abede87", + "metadata": { + "id": "da9dbcd2-b516-4e55-b998-712e2abede87" + }, + "outputs": [], + "source": [ + "q_layer0 = model[\"model.layers.0.self_attn.q_proj.weight\"]\n", + "k_layer0 = model[\"model.layers.0.self_attn.k_proj.weight\"]\n", + "v_layer0 = model[\"model.layers.0.self_attn.v_proj.weight\"]\n", + "o_layer0 = model[\"model.layers.0.self_attn.o_proj.weight\"]\n", + "q_layer0_bias = model['model.layers.0.self_attn.q_proj.bias']\n", + "k_layer0_bias = model['model.layers.0.self_attn.k_proj.bias']\n", + "v_layer0_bias = model['model.layers.0.self_attn.v_proj.bias']" + ] + }, + { + "cell_type": "markdown", + "id": "dc492a89-0e9c-435b-96ca-76a365937d2a", + "metadata": { + "id": "dc492a89-0e9c-435b-96ca-76a365937d2a" + }, + "source": [ + "## Now, we receive the query, key, and value projections for the tokens\n", + "The query states have shape [10, 2048], where 10 is the number of input tokens and 2048 is the hidden-state dimension; the key and value states have shape [10, 256] because this model uses grouped-query attention (2 KV heads of dim 128)."
+ ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "a6d3af1d-b3df-420b-8903-0324e70d5d20", + "metadata": { + "id": "a6d3af1d-b3df-420b-8903-0324e70d5d20" + }, + "outputs": [], + "source": [ + "query_states = torch.matmul(token_embeddings, q_layer0.T)+q_layer0_bias\n", + "key_states = torch.matmul(token_embeddings, k_layer0.T)+k_layer0_bias\n", + "value_states = torch.matmul(token_embeddings, v_layer0.T)+v_layer0_bias" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "d7c1baf0-e830-4265-8ff0-8d6181d964e5", + "metadata": { + "id": "d7c1baf0-e830-4265-8ff0-8d6181d964e5" + }, + "outputs": [], + "source": [ + "q_len = len(tokens)\n", + "head_dim = dim//n_heads\n", + "query_states = query_states.view(1, q_len, n_heads, head_dim).transpose(1, 2)\n", + "key_states = key_states.view(1, q_len, n_kv_heads, head_dim).transpose(1, 2)\n", + "value_states = value_states.view(1, q_len, n_kv_heads, head_dim).transpose(1, 2)" + ] + }, + { + "cell_type": "markdown", + "id": "bbd657fd-65b8-4e4b-b902-c33122ff3055", + "metadata": { + "id": "bbd657fd-65b8-4e4b-b902-c33122ff3055" + }, + "source": [ + "## Positioning encoding\n", + "Due to query, key, and value can not represent the position information of tokens. Transformers are designed to handle sequences of data, but unlike recurrent neural networks (RNNs), they do not process the data in a sequential order. 
Positional encoding addresses this by adding information about the position of each token in the sequence, enabling the model to understand the order and relative position of tokens.\n", + "\n", + "### RoPE\n", + "watch this video (this is what i watched) to understand the math.\n", + "**https://www.youtube.com/watch?v=o29P0Kpobz0&t=530s**\n", + "\n", + "### Here I use the original positional encoding code from QWen2\n", + "Qwen2RotaryEmbedding() is used to generate rotating position encoding to efficiently provide position encoding for input sequences by calculating and caching cosine and sine values." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "965d1db5-5fb2-46f8-a276-1ae5567a90be", + "metadata": { + "id": "965d1db5-5fb2-46f8-a276-1ae5567a90be" + }, + "outputs": [], + "source": [ + "class Qwen2RotaryEmbedding(nn.Module):\n", + " def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):\n", + " super().__init__()\n", + "\n", + " self.dim = dim\n", + " self.max_position_embeddings = max_position_embeddings\n", + " self.base = base\n", + " inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))\n", + " self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n", + "\n", + " # Build here to make `torch.jit.trace` work.\n", + " self._set_cos_sin_cache(\n", + " seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()\n", + " )\n", + "\n", + " def _set_cos_sin_cache(self, seq_len, device, dtype):\n", + " self.max_seq_len_cached = seq_len\n", + " t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)\n", + "\n", + " freqs = torch.outer(t, self.inv_freq)\n", + " # Different from paper, but it uses a different permutation in order to obtain the same calculation\n", + " emb = torch.cat((freqs, freqs), dim=-1)\n", + " self.register_buffer(\"cos_cached\", emb.cos().to(dtype), 
persistent=False)\n", + " self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n", + "\n", + " def forward(self, x, seq_len=None):\n", + " # x: [bs, num_attention_heads, seq_len, head_size]\n", + " if seq_len > self.max_seq_len_cached:\n", + " self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)\n", + "\n", + " return (\n", + " self.cos_cached[:seq_len].to(dtype=x.dtype),\n", + " self.sin_cached[:seq_len].to(dtype=x.dtype),\n", + " )\n", + "rotary_emb = Qwen2RotaryEmbedding(\n", + " 128,\n", + " max_position_embeddings=131072,\n", + " base=rope_theta,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "23f8e3f1-dd11-4dae-a827-ba67b6f657f1", + "metadata": { + "id": "23f8e3f1-dd11-4dae-a827-ba67b6f657f1" + }, + "source": [ + "## apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1)\n", + "By combining query and key tensors with cosine and sine values, including rotation operations, positional coding information is embedded in the query and key tensors.\n", + "## rotate_half(x)\n", + "This rotation operation allows each element of the vector to be combined with the cosine and sine values of the corresponding position, thereby changing the direction and amplitude of the vector." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "d53e5a0b-aa06-473b-81cb-cbc9892b2574", + "metadata": { + "id": "d53e5a0b-aa06-473b-81cb-cbc9892b2574" + }, + "outputs": [], + "source": [ + "def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):\n", + " \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n", + "\n", + " Args:\n", + " q (`torch.Tensor`): The query tensor.\n", + " k (`torch.Tensor`): The key tensor.\n", + " cos (`torch.Tensor`): The cosine part of the rotary embedding.\n", + " sin (`torch.Tensor`): The sine part of the rotary embedding.\n", + " position_ids (`torch.Tensor`):\n", + " The position indices of the tokens corresponding to the query and key tensors. 
For example, this can be\n", + " used to pass offsetted position ids when working with a KV-cache.\n", + " unsqueeze_dim (`int`, *optional*, defaults to 1):\n", + " The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n", + " sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n", + " that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n", + " k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n", + " cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have\n", + " the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n", + " Returns:\n", + " `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n", + " \"\"\"\n", + " cos = cos[position_ids].unsqueeze(unsqueeze_dim)\n", + " sin = sin[position_ids].unsqueeze(unsqueeze_dim)\n", + " q_embed = (q * cos) + (rotate_half(q) * sin)\n", + " k_embed = (k * cos) + (rotate_half(k) * sin)\n", + " return q_embed, k_embed\n", + "\n", + "\n", + "def rotate_half(x):\n", + " \"\"\"Rotates half the hidden dims of the input.\"\"\"\n", + " x1 = x[..., : x.shape[-1] // 2]\n", + " x2 = x[..., x.shape[-1] // 2 :]\n", + " return torch.cat((-x2, x1), dim=-1)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "be770e30-e653-4b4b-84ed-110a19533660", + "metadata": { + "id": "be770e30-e653-4b4b-84ed-110a19533660" + }, + "outputs": [], + "source": [ + "cos, sin = rotary_emb(value_states, seq_len=q_len)\n", + "position_ids = torch.arange(q_len).view(1,q_len)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "b44c88a9-e06b-4a58-a4b7-62df0dd94668", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 279 + }, + "id": "b44c88a9-e06b-4a58-a4b7-62df0dd94668", + 
"outputId": "8c7ad846-b2f2-4392-97b7-994ae2bcf790" + }, + "outputs": [ + { + "output_type": "error", + "ename": "RuntimeError", + "evalue": "indices should be either on cpu or on the same device as the indexed tensor (cpu)", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mquery_states\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey_states\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mapply_rotary_pos_emb\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquery_states\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey_states\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msin\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mposition_ids\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m\u001b[0m in \u001b[0;36mapply_rotary_pos_emb\u001b[0;34m(q, k, cos, sin, position_ids, unsqueeze_dim)\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;31m`\u001b[0m\u001b[0mtuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;31m`\u001b[0m \u001b[0mcomprising\u001b[0m \u001b[0mof\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mquery\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mkey\u001b[0m \u001b[0mtensors\u001b[0m \u001b[0mrotated\u001b[0m \u001b[0musing\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mRotary\u001b[0m \u001b[0mPosition\u001b[0m \u001b[0mEmbedding\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \"\"\"\n\u001b[0;32m---> 22\u001b[0;31m \u001b[0mcos\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mcos\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mposition_ids\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0munsqueeze_dim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 23\u001b[0m \u001b[0msin\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msin\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mposition_ids\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0munsqueeze_dim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0mq_embed\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mq\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mcos\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mrotate_half\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mq\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0msin\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: indices should be either on cpu or on the same device as the indexed tensor (cpu)" + ] + } + ], + "source": [ + "query_states, key_states = apply_rotary_pos_emb(query_states.cpu(), key_states.cpu(), cos, sin, position_ids.to)\n", + "query_states.to(device)\n", + "key_states.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ad4b29e6-d009-4aa0-b968-9d78fd06492c", + "metadata": { + "id": "ad4b29e6-d009-4aa0-b968-9d78fd06492c" + }, + "outputs": [], + "source": [ + "def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n", + " \"\"\"\n", + " This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,\n", + " num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n", + " \"\"\"\n", + " batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n", + " if n_rep == 1:\n", + " return hidden_states\n", + " hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n", + " return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "85b7fb25-cb30-4a13-8da0-1393495b17c1", + "metadata": { + "id": "85b7fb25-cb30-4a13-8da0-1393495b17c1" + }, + "outputs": [], + "source": [ + "key_states = repeat_kv(key_states, n_heads // n_kv_heads)\n", + "value_states = repeat_kv(value_states, n_heads // n_kv_heads)" + ] + }, + { + "cell_type": "markdown", + "id": "79329733-28a8-466d-aa89-4d7e988d1312", + "metadata": { + "id": "79329733-28a8-466d-aa89-4d7e988d1312" + }, + "source": [ + "## Scaled Dot-Product Attention\n", + "This is a core attention mechanism in the Transformer architecture that allows the model to dynamically adjust its focus to different locations based on correlations in the input sequence. Specifically, this function performs dot product attention calculations on query, key, and value tensors.\n", + "\n", + "
\n", + " \n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2174fff2-0989-4b2c-bfc9-7e5e37e47db9", + "metadata": { + "id": "2174fff2-0989-4b2c-bfc9-7e5e37e47db9" + }, + "outputs": [], + "source": [ + "attn_output = torch.nn.functional.scaled_dot_product_attention(\n", + " query_states,\n", + " key_states,\n", + " value_states,\n", + " attn_mask=None,\n", + " dropout_p= 0.0,\n", + " # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.\n", + " is_causal= True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c69e00a9-2373-4937-b2d3-ed3eae4481b3", + "metadata": { + "id": "c69e00a9-2373-4937-b2d3-ed3eae4481b3" + }, + "outputs": [], + "source": [ + "attn_output = attn_output.transpose(1, 2).contiguous()\n", + "attn_output = attn_output.view(1, q_len, dim)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0d3dd439-a258-48a6-9e1d-13829c94a078", + "metadata": { + "id": "0d3dd439-a258-48a6-9e1d-13829c94a078", + "outputId": "20b052da-792e-4b78-ed48-8a1547834262" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[-1.8568e-01, 1.3149e-01, -1.5167e-01, ..., 2.8487e-02,\n", + " -8.3742e-02, -2.3384e-02],\n", + " [-1.2537e-01, 2.0195e-01, -1.2300e-02, ..., -3.6986e-02,\n", + " -1.8594e-01, 9.9794e-02],\n", + " [-1.4426e-01, 1.5807e-01, -1.7747e-01, ..., -7.1516e-02,\n", + " 7.0311e-02, -1.7331e-01],\n", + " ...,\n", + " [-5.9189e-02, 4.0363e-02, -1.3974e-05, ..., -5.2831e-02,\n", + " -2.0385e-02, 8.6324e-03],\n", + " [ 3.7043e-02, 5.2902e-02, 3.0693e-03, ..., -8.9145e-02,\n", + " -1.0277e-01, 1.0480e-02],\n", + " [-8.8573e-02, 1.8764e-02, -4.4170e-02, ..., 1.4842e-01,\n", + " -9.0892e-02, 5.9852e-02]]])" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output_states = torch.matmul(attn_output, o_layer0.T)\n", + "output_states" + ] + }, + { + "cell_type": 
"markdown", + "id": "231cd220-64a2-408f-a644-87a717ee46df", + "metadata": { + "id": "231cd220-64a2-408f-a644-87a717ee46df" + }, + "source": [ + "## Residual neural networks\n", + "The vanishing and exploding gradient problems can be mitigated by introducing residual connections.\n", + "
\n", + " \n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "76a4648d-938a-435c-9ca0-918decdc988b", + "metadata": { + "id": "76a4648d-938a-435c-9ca0-918decdc988b" + }, + "outputs": [], + "source": [ + "output_states = output_states+token_embeddings_unnormalized" + ] + }, + { + "cell_type": "markdown", + "id": "a5129ff6-263e-4ab5-9bc8-f5396e457e63", + "metadata": { + "id": "a5129ff6-263e-4ab5-9bc8-f5396e457e63" + }, + "source": [ + "## Normalize and then run a feed forward neural network through the embedding delta\n", + "
\n", + " \n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a9b06228-a58b-4dde-9f30-78e307c1f9ef", + "metadata": { + "id": "a9b06228-a58b-4dde-9f30-78e307c1f9ef" + }, + "outputs": [], + "source": [ + "second_normalized = rms_norm(output_states, model[\"model.layers.0.post_attention_layernorm.weight\"])" + ] + }, + { + "cell_type": "markdown", + "id": "104c6689-096f-4af2-9634-22ce3ee30369", + "metadata": { + "id": "104c6689-096f-4af2-9634-22ce3ee30369" + }, + "source": [ + "## Loading the ff weights and implementing the ffn\n", + "
\n", + " \n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f999c5c9-b451-4c32-aae7-c5b7bc18e761", + "metadata": { + "id": "f999c5c9-b451-4c32-aae7-c5b7bc18e761" + }, + "outputs": [], + "source": [ + "w1 = model[f\"model.layers.0.mlp.gate_proj.weight\"]\n", + "w2 = model[f\"model.layers.0.mlp.down_proj.weight\"]\n", + "w3 = model[f\"model.layers.0.mlp.up_proj.weight\"]\n", + "output_after_feedforward = torch.matmul(torch.functional.F.silu(torch.matmul(second_normalized, w1.T)) * torch.matmul(second_normalized, w3.T), w2.T)" + ] + }, + { + "cell_type": "markdown", + "id": "9b9e1f8e-f41a-4cfe-8979-21cb5b6e1fd4", + "metadata": { + "id": "9b9e1f8e-f41a-4cfe-8979-21cb5b6e1fd4" + }, + "source": [ + "## Everything is done!!!~\n", + "Now, run them at once!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a131a7af-666f-4c45-a5d5-3e093384a57b", + "metadata": { + "id": "a131a7af-666f-4c45-a5d5-3e093384a57b" + }, + "outputs": [], + "source": [ + "final_embedding = token_embeddings_unnormalized\n", + "x= 0\n", + "for layer in range(n_layers):\n", + " x+=1\n", + " residual1 = final_embedding\n", + "\n", + " # embeding norm\n", + " layer_embedding_norm = rms_norm(final_embedding, model[f\"model.layers.{layer}.input_layernorm.weight\"])\n", + "\n", + " q_layer = model[f\"model.layers.{layer}.self_attn.q_proj.weight\"]\n", + " k_layer = model[f\"model.layers.{layer}.self_attn.k_proj.weight\"]\n", + " v_layer = model[f\"model.layers.{layer}.self_attn.v_proj.weight\"]\n", + " w_layer = model[f\"model.layers.{layer}.self_attn.o_proj.weight\"]\n", + " q_layer_bias = model[f'model.layers.{layer}.self_attn.q_proj.bias']\n", + " k_layer_bias = model[f'model.layers.{layer}.self_attn.k_proj.bias']\n", + " v_layer_bias = model[f'model.layers.{layer}.self_attn.v_proj.bias']\n", + "\n", + " query_states = torch.matmul(layer_embedding_norm, q_layer.T)+q_layer_bias\n", + " key_states = torch.matmul(layer_embedding_norm, k_layer.T)+k_layer_bias\n", + " 
value_states = torch.matmul(layer_embedding_norm, v_layer.T)+v_layer_bias\n", + " head_dim = dim//n_heads\n", + " query_states = query_states.view(1, q_len, n_heads, head_dim).transpose(1, 2)\n", + " key_states = key_states.view(1, q_len, n_kv_heads, head_dim).transpose(1, 2)\n", + " value_states = value_states.view(1, q_len, n_kv_heads, head_dim).transpose(1, 2)\n", + "\n", + " cos, sin = rotary_emb(value_states, seq_len=q_len)\n", + " position_ids = torch.arange(q_len).view(1,q_len)\n", + " query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n", + "\n", + " key_states = repeat_kv(key_states, n_heads // n_kv_heads)\n", + " value_states = repeat_kv(value_states, n_heads // n_kv_heads)\n", + "\n", + " attn_output = torch.nn.functional.scaled_dot_product_attention(\n", + " query_states,\n", + " key_states,\n", + " value_states,\n", + " attn_mask=None,\n", + " dropout_p= 0.0,\n", + " # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.\n", + " is_causal= True,\n", + " )\n", + "\n", + "\n", + "\n", + " attn_output = attn_output.transpose(1, 2).contiguous()\n", + " attn_output = attn_output.view(1, q_len, dim)\n", + " output_states = torch.matmul(attn_output, w_layer.T)\n", + "\n", + " hidden_state = residual1+output_states\n", + "\n", + " # Fully connected\n", + " residual2 = hidden_state\n", + "\n", + " w1 = model[f\"model.layers.{layer}.mlp.gate_proj.weight\"]\n", + " w2 = model[f\"model.layers.{layer}.mlp.down_proj.weight\"]\n", + " w3 = model[f\"model.layers.{layer}.mlp.up_proj.weight\"]\n", + " second_normalized = rms_norm(hidden_state, model[f\"model.layers.{layer}.post_attention_layernorm.weight\"])\n", + " output_after_feedforward = torch.matmul(torch.functional.F.silu(torch.matmul(second_normalized, w1.T)) * torch.matmul(second_normalized, w3.T), w2.T)\n", + " final_embedding = residual2+output_after_feedforward" + ] + }, + { + 
"cell_type": "markdown", + "id": "97719c73-4029-488a-88c1-2f58ace2b5e6", + "metadata": { + "id": "97719c73-4029-488a-88c1-2f58ace2b5e6" + }, + "source": [ + "## Here is the final embedding\n", + "The shape of it is same as the first embedding [10*3584].\n", + "
\n", + " \n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e319e392-14f4-4875-bb7e-6ab1ff3f9e46", + "metadata": { + "id": "e319e392-14f4-4875-bb7e-6ab1ff3f9e46", + "outputId": "8d1bf0ea-adf5-4939-d35a-858deed79044" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 10, 3584])" + ] + }, + "execution_count": 51, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "final_normalized = rms_norm(final_embedding, model[\"model.norm.weight\"])\n", + "final_normalized.shape" + ] + }, + { + "cell_type": "markdown", + "id": "2c490e12-f1e9-4656-b67c-0b18fa652994", + "metadata": { + "id": "2c490e12-f1e9-4656-b67c-0b18fa652994" + }, + "source": [ + "## Finally!!! We can decode the embedding into the token value!\n", + "
\n", + " \n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a59710bf-2aa5-4018-8e45-393c1af9923b", + "metadata": { + "id": "a59710bf-2aa5-4018-8e45-393c1af9923b", + "outputId": "309bc665-e547-4589-d0e9-d7057a210093" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([152064])" + ] + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "logits = torch.matmul(final_normalized[0][-1], model[\"lm_head.weight\"].T)\n", + "logits.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fa370468-0279-4a77-ab89-259e082b068e", + "metadata": { + "id": "fa370468-0279-4a77-ab89-259e082b068e", + "outputId": "ccb50178-2d6a-452f-846b-ca89ac8d3ae8" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([55806])" + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "next_token = torch.argmax(logits, dim=-1).view(1)\n", + "next_token" + ] + }, + { + "cell_type": "markdown", + "id": "8775ab65-bf3f-4cdb-ad15-53ad8d1ad8fe", + "metadata": { + "id": "8775ab65-bf3f-4cdb-ad15-53ad8d1ad8fe" + }, + "source": [ + "# Oh! yeah!~~~\n", + "
\n", + " \n", + "
\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b543329b-3a85-4bf8-850e-2616583c6cdb", + "metadata": { + "id": "b543329b-3a85-4bf8-850e-2616583c6cdb", + "outputId": "5363cdaf-a5a0-4a0e-ee66-1b6086f06d2e" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'退'" + ] + }, + "execution_count": 54, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.decode(next_token)" + ] + }, + { + "cell_type": "markdown", + "id": "a3afc986-9c4a-4e69-a7dd-5c25c9a5eaab", + "metadata": { + "id": "a3afc986-9c4a-4e69-a7dd-5c25c9a5eaab" + }, + "source": [ + "# Here I need to appreciate Naklecha's Llama3 work\n", + "According to his **[Llama3-from-scratch](https://github.com/naklecha/llama3-from-scratch)**, I totally understand the structure of a decoder-only LLM.\n", + "\n", + "# In addition, I also want to broadcast Chinese LLM\n", + "Performence of Qwen2 has improved so much comparing to the previous version.\n", + "\n", + "# Help LLM beginner\n", + "Due to I'm not computer science graduates, I really meet so many problems. 
I hope my project can help these people who want to learn LLM.\n", + "\n", + "If you have any suggestions, plz and don't hesitate and contact me!!\n", + "\n", + "My RED num is 495668258" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29569b1d-46b8-4dfa-9676-78236e24c80b", + "metadata": { + "id": "29569b1d-46b8-4dfa-9676-78236e24c80b" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mistral", + "language": "python", + "name": "mistral" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.18" + }, + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "accelerator": "GPU", + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "26c723f59a96410489b443cbef0cb4a2": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_5aeeeb91317a4d91b5a4f04a6a92fe99", + "IPY_MODEL_29a762cdc7194ed5a9678abc3f55f496", + "IPY_MODEL_b50c06cfcbb74502a75ddec0e87e7157" + ], + "layout": "IPY_MODEL_474d5a23fe984a2fa80e13f0cbd187bb" + } + }, + "5aeeeb91317a4d91b5a4f04a6a92fe99": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": 
"1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_d656cf9b3c414230828e38146096942c", + "placeholder": "​", + "style": "IPY_MODEL_e4f07a621e6b4eb29380afcda7af8ba1", + "value": "Loading checkpoint shards: 100%" + } + }, + "29a762cdc7194ed5a9678abc3f55f496": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_bdc1353aea5b4a5c86d279efeb483f7f", + "max": 2, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_9ae197159ee44b629c2e0b440dae62cb", + "value": 2 + } + }, + "b50c06cfcbb74502a75ddec0e87e7157": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_68e131ba08a744f6b085d75bb59c674e", + "placeholder": "​", + "style": "IPY_MODEL_f451c98a8f1a4ee99380d30ed185855e", + "value": " 2/2 [00:01<00:00,  2.00it/s]" + } + }, + "474d5a23fe984a2fa80e13f0cbd187bb": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": 
"@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d656cf9b3c414230828e38146096942c": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + 
"order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e4f07a621e6b4eb29380afcda7af8ba1": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "bdc1353aea5b4a5c86d279efeb483f7f": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "9ae197159ee44b629c2e0b440dae62cb": { + "model_module": "@jupyter-widgets/controls", + "model_name": 
"ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "68e131ba08a744f6b085d75bb59c674e": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f451c98a8f1a4ee99380d30ed185855e": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + 
"_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file From b7cb31f46b35802a76bd1490b29e65ce7ba82a29 Mon Sep 17 00:00:00 2001 From: hkxhrwang Date: Sun, 12 Jan 2025 22:14:00 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E4=BD=BF=E7=94=A8=20Colab=20=E5=88=9B?= =?UTF-8?q?=E5=BB=BA=E8=80=8C=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Deconstructing-QWen2-from-Ground-Up.ipynb | 73 ++++++++++++++++------- 1 file changed, 51 insertions(+), 22 deletions(-) diff --git a/Deconstructing-QWen2-from-Ground-Up.ipynb b/Deconstructing-QWen2-from-Ground-Up.ipynb index d4a037fc..f067f7b3 100644 --- a/Deconstructing-QWen2-from-Ground-Up.ipynb +++ b/Deconstructing-QWen2-from-Ground-Up.ipynb @@ -642,39 +642,51 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 27, "id": "b44c88a9-e06b-4a58-a4b7-62df0dd94668", "metadata": { "colab": { - "base_uri": "https://localhost:8080/", - "height": 279 + "base_uri": "https://localhost:8080/" }, "id": "b44c88a9-e06b-4a58-a4b7-62df0dd94668", - "outputId": "8c7ad846-b2f2-4392-97b7-994ae2bcf790" + "outputId": "069ac24b-ac29-48d3-ee26-e97cfd429867" }, "outputs": [ { - "output_type": "error", - "ename": "RuntimeError", - "evalue": "indices should be either on cpu or on the same device as the indexed tensor (cpu)", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mquery_states\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey_states\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mapply_rotary_pos_emb\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquery_states\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mkey_states\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msin\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mposition_ids\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m\u001b[0m in \u001b[0;36mapply_rotary_pos_emb\u001b[0;34m(q, k, cos, sin, position_ids, unsqueeze_dim)\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;31m`\u001b[0m\u001b[0mtuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;31m`\u001b[0m \u001b[0mcomprising\u001b[0m \u001b[0mof\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mquery\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mkey\u001b[0m \u001b[0mtensors\u001b[0m \u001b[0mrotated\u001b[0m \u001b[0musing\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mRotary\u001b[0m \u001b[0mPosition\u001b[0m \u001b[0mEmbedding\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \"\"\"\n\u001b[0;32m---> 22\u001b[0;31m \u001b[0mcos\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcos\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mposition_ids\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0munsqueeze_dim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 23\u001b[0m \u001b[0msin\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msin\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mposition_ids\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0munsqueeze_dim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0mq_embed\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mq\u001b[0m \u001b[0;34m*\u001b[0m 
\u001b[0mcos\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mrotate_half\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mq\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0msin\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mRuntimeError\u001b[0m: indices should be either on cpu or on the same device as the indexed tensor (cpu)" - ] + "output_type": "execute_result", + "data": { + "text/plain": [ + "tensor([[[[ -3.2188, -0.1963, 2.9219, ..., 56.0000, 26.6250, 86.5000],\n", + " [ -4.5000, 2.2812, -1.6953, ..., 57.0000, 25.0000, 88.0000],\n", + " [ -1.2031, 3.8594, -3.0625, ..., 55.7500, 24.2500, 87.0000],\n", + " ...,\n", + " [ -4.5000, -3.3750, 3.6250, ..., 55.5000, 19.8750, 88.0000],\n", + " [ -2.5625, -0.3281, 2.7969, ..., 57.5000, 25.2500, 86.0000],\n", + " [ 1.4453, 2.5781, 1.5469, ..., 56.7500, 24.6250, 87.0000]],\n", + "\n", + " [[ -1.7969, -0.1318, 2.4844, ..., 46.5000, -67.0000, 94.0000],\n", + " [ 1.0156, 2.6250, 0.5938, ..., 46.7500, -68.0000, 91.0000],\n", + " [ 4.3125, 3.4062, -0.2812, ..., 47.7500, -69.0000, 91.5000],\n", + " ...,\n", + " [ 1.0938, -1.6172, 1.4219, ..., 47.0000, -68.5000, 91.0000],\n", + " [ 4.0625, 0.5703, 2.1406, ..., 47.2500, -69.0000, 91.0000],\n", + " [ 3.0938, 2.5781, 2.2656, ..., 46.7500, -68.0000, 91.5000]]]],\n", + " device='cuda:0', dtype=torch.bfloat16)" + ] + }, + "metadata": {}, + "execution_count": 27 } ], "source": [ - "query_states, key_states = apply_rotary_pos_emb(query_states.cpu(), key_states.cpu(), cos, sin, position_ids.to)\n", - "query_states.to(device)\n", - "key_states.to(device)" + "query_states, key_states = apply_rotary_pos_emb(query_states.cpu(), key_states.cpu(), cos, sin, position_ids)\n", + "query_states = query_states.to(device)\n", + "key_states = key_states.to(device)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 28, "id": "ad4b29e6-d009-4aa0-b968-9d78fd06492c", "metadata": { 
"id": "ad4b29e6-d009-4aa0-b968-9d78fd06492c" @@ -695,7 +707,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 29, "id": "85b7fb25-cb30-4a13-8da0-1393495b17c1", "metadata": { "id": "85b7fb25-cb30-4a13-8da0-1393495b17c1" @@ -723,12 +735,29 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 30, "id": "2174fff2-0989-4b2c-bfc9-7e5e37e47db9", "metadata": { - "id": "2174fff2-0989-4b2c-bfc9-7e5e37e47db9" + "colab": { + "base_uri": "https://localhost:8080/", + "height": 223 + }, + "id": "2174fff2-0989-4b2c-bfc9-7e5e37e47db9", + "outputId": "717227ba-a852-4501-cf78-787e3c253eb3" }, - "outputs": [], + "outputs": [ + { + "output_type": "error", + "ename": "RuntimeError", + "evalue": "Expected query, key, and value to have the same device type, but got query.device: cpu key.device: cpu and value.device: cuda:0 instead.", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m attn_output = torch.nn.functional.scaled_dot_product_attention(\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mquery_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mkey_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mvalue_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mattn_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: Expected query, key, and value to have the same device type, but got query.device: cpu key.device: cpu and value.device: cuda:0 instead." 
+ ] + } + ], "source": [ "attn_output = torch.nn.functional.scaled_dot_product_attention(\n", " query_states,\n", From 1c9ae9b2c537fa9548532189ac9c24db97256bd6 Mon Sep 17 00:00:00 2001 From: hkxhrwang Date: Sun, 12 Jan 2025 22:37:45 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E4=BD=BF=E7=94=A8=20Colab=20=E5=88=9B?= =?UTF-8?q?=E5=BB=BA=E8=80=8C=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Deconstructing-QWen2-from-Ground-Up.ipynb | 380 ++++++++++++++-------- 1 file changed, 242 insertions(+), 138 deletions(-) diff --git a/Deconstructing-QWen2-from-Ground-Up.ipynb b/Deconstructing-QWen2-from-Ground-Up.ipynb index f067f7b3..ec29db77 100644 --- a/Deconstructing-QWen2-from-Ground-Up.ipynb +++ b/Deconstructing-QWen2-from-Ground-Up.ipynb @@ -445,12 +445,30 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 53, "id": "da9dbcd2-b516-4e55-b998-712e2abede87", "metadata": { - "id": "da9dbcd2-b516-4e55-b998-712e2abede87" + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "da9dbcd2-b516-4e55-b998-712e2abede87", + "outputId": "b3dc54e0-d569-4483-c627-b9f91624dd67" }, - "outputs": [], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "q_layer0.shape: torch.Size([2048, 2048])\n", + "k_layer0.shape: torch.Size([256, 2048])\n", + "v_layer0.shape: torch.Size([256, 2048])\n", + "o_layer0.shape: torch.Size([2048, 2048])\n", + "q_layer0_bias.shape: torch.Size([2048])\n", + "k_layer0_bias.shape: torch.Size([256])\n", + "v_layer0_bias.shape: torch.Size([256])\n" + ] + } + ], "source": [ "q_layer0 = model[\"model.layers.0.self_attn.q_proj.weight\"]\n", "k_layer0 = model[\"model.layers.0.self_attn.k_proj.weight\"]\n", @@ -458,7 +476,14 @@ "o_layer0 = model[\"model.layers.0.self_attn.o_proj.weight\"]\n", "q_layer0_bias = model['model.layers.0.self_attn.q_proj.bias']\n", "k_layer0_bias = model['model.layers.0.self_attn.k_proj.bias']\n", - "v_layer0_bias = 
model['model.layers.0.self_attn.v_proj.bias']" + "v_layer0_bias = model['model.layers.0.self_attn.v_proj.bias']\n", + "print(f\"q_layer0.shape: {q_layer0.shape}\")\n", + "print(f\"k_layer0.shape: {k_layer0.shape}\")\n", + "print(f\"v_layer0.shape: {v_layer0.shape}\")\n", + "print(f\"o_layer0.shape: {o_layer0.shape}\")\n", + "print(f\"q_layer0_bias.shape: {q_layer0_bias.shape}\")\n", + "print(f\"k_layer0_bias.shape: {k_layer0_bias.shape}\")\n", + "print(f\"v_layer0_bias.shape: {v_layer0_bias.shape}\")" ] }, { @@ -474,21 +499,38 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 54, "id": "a6d3af1d-b3df-420b-8903-0324e70d5d20", "metadata": { - "id": "a6d3af1d-b3df-420b-8903-0324e70d5d20" + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "a6d3af1d-b3df-420b-8903-0324e70d5d20", + "outputId": "03824e0e-17f0-470d-b73b-81af905e5863" }, - "outputs": [], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "query_states.shape: torch.Size([10, 2048])\n", + "key_states.shape: torch.Size([10, 256])\n", + "value_states.shape: torch.Size([10, 256])\n" + ] + } + ], "source": [ "query_states = torch.matmul(token_embeddings, q_layer0.T)+q_layer0_bias\n", "key_states = torch.matmul(token_embeddings, k_layer0.T)+k_layer0_bias\n", - "value_states = torch.matmul(token_embeddings, v_layer0.T)+v_layer0_bias" + "value_states = torch.matmul(token_embeddings, v_layer0.T)+v_layer0_bias\n", + "print(f\"query_states.shape: {query_states.shape}\")\n", + "print(f\"key_states.shape: {key_states.shape}\")\n", + "print(f\"value_states.shape: {value_states.shape}\")" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 55, "id": "d7c1baf0-e830-4265-8ff0-8d6181d964e5", "metadata": { "id": "d7c1baf0-e830-4265-8ff0-8d6181d964e5" @@ -522,7 +564,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 56, "id": "965d1db5-5fb2-46f8-a276-1ae5567a90be", "metadata": { "id": 
"965d1db5-5fb2-46f8-a276-1ae5567a90be" @@ -585,7 +627,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 57, "id": "d53e5a0b-aa06-473b-81cb-cbc9892b2574", "metadata": { "id": "d53e5a0b-aa06-473b-81cb-cbc9892b2574" @@ -629,64 +671,66 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 58, "id": "be770e30-e653-4b4b-84ed-110a19533660", "metadata": { - "id": "be770e30-e653-4b4b-84ed-110a19533660" + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "be770e30-e653-4b4b-84ed-110a19533660", + "outputId": "1db438b7-b841-449d-ee10-8196d1399003" }, - "outputs": [], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "cos.shape: torch.Size([10, 128])\n", + "sin.shape: torch.Size([10, 128])\n", + "position_ids.shape: torch.Size([1, 10])\n" + ] + } + ], "source": [ "cos, sin = rotary_emb(value_states, seq_len=q_len)\n", - "position_ids = torch.arange(q_len).view(1,q_len)" + "position_ids = torch.arange(q_len).view(1,q_len)\n", + "print(f\"cos.shape: {cos.shape}\")\n", + "print(f\"sin.shape: {sin.shape}\")\n", + "print(f\"position_ids.shape: {position_ids.shape}\")" ] }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 59, "id": "b44c88a9-e06b-4a58-a4b7-62df0dd94668", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "b44c88a9-e06b-4a58-a4b7-62df0dd94668", - "outputId": "069ac24b-ac29-48d3-ee26-e97cfd429867" + "outputId": "69f0e5ce-9c4a-40d4-b5d2-bc15ec50452c" }, "outputs": [ { - "output_type": "execute_result", - "data": { - "text/plain": [ - "tensor([[[[ -3.2188, -0.1963, 2.9219, ..., 56.0000, 26.6250, 86.5000],\n", - " [ -4.5000, 2.2812, -1.6953, ..., 57.0000, 25.0000, 88.0000],\n", - " [ -1.2031, 3.8594, -3.0625, ..., 55.7500, 24.2500, 87.0000],\n", - " ...,\n", - " [ -4.5000, -3.3750, 3.6250, ..., 55.5000, 19.8750, 88.0000],\n", - " [ -2.5625, -0.3281, 2.7969, ..., 57.5000, 25.2500, 86.0000],\n", - " [ 1.4453, 2.5781, 1.5469, ..., 56.7500, 
24.6250, 87.0000]],\n", - "\n", - " [[ -1.7969, -0.1318, 2.4844, ..., 46.5000, -67.0000, 94.0000],\n", - " [ 1.0156, 2.6250, 0.5938, ..., 46.7500, -68.0000, 91.0000],\n", - " [ 4.3125, 3.4062, -0.2812, ..., 47.7500, -69.0000, 91.5000],\n", - " ...,\n", - " [ 1.0938, -1.6172, 1.4219, ..., 47.0000, -68.5000, 91.0000],\n", - " [ 4.0625, 0.5703, 2.1406, ..., 47.2500, -69.0000, 91.0000],\n", - " [ 3.0938, 2.5781, 2.2656, ..., 46.7500, -68.0000, 91.5000]]]],\n", - " device='cuda:0', dtype=torch.bfloat16)" - ] - }, - "metadata": {}, - "execution_count": 27 + "output_type": "stream", + "name": "stdout", + "text": [ + "query_states.shape: torch.Size([1, 16, 10, 128])\n", + "key_states.shape: torch.Size([1, 2, 10, 128])\n" + ] } ], "source": [ "query_states, key_states = apply_rotary_pos_emb(query_states.cpu(), key_states.cpu(), cos, sin, position_ids)\n", "query_states = query_states.to(device)\n", - "key_states = key_states.to(device)" + "key_states = key_states.to(device)\n", + "print(f\"query_states.shape: {query_states.shape}\")\n", + "print(f\"key_states.shape: {key_states.shape}\")" ] }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 60, "id": "ad4b29e6-d009-4aa0-b968-9d78fd06492c", "metadata": { "id": "ad4b29e6-d009-4aa0-b968-9d78fd06492c" @@ -707,15 +751,30 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 61, "id": "85b7fb25-cb30-4a13-8da0-1393495b17c1", "metadata": { - "id": "85b7fb25-cb30-4a13-8da0-1393495b17c1" + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "85b7fb25-cb30-4a13-8da0-1393495b17c1", + "outputId": "08c661dd-22b5-45f0-caf9-058f71de9a63" }, - "outputs": [], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "key_states.shape: torch.Size([1, 16, 10, 128])\n", + "value_states.shape: torch.Size([1, 16, 10, 128])\n" + ] + } + ], "source": [ "key_states = repeat_kv(key_states, n_heads // n_kv_heads)\n", - "value_states = repeat_kv(value_states, n_heads // 
n_kv_heads)" + "value_states = repeat_kv(value_states, n_heads // n_kv_heads)\n", + "print(f\"key_states.shape: {key_states.shape}\")\n", + "print(f\"value_states.shape: {value_states.shape}\")" ] }, { @@ -735,26 +794,21 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 64, "id": "2174fff2-0989-4b2c-bfc9-7e5e37e47db9", "metadata": { "colab": { - "base_uri": "https://localhost:8080/", - "height": 223 + "base_uri": "https://localhost:8080/" }, "id": "2174fff2-0989-4b2c-bfc9-7e5e37e47db9", - "outputId": "717227ba-a852-4501-cf78-787e3c253eb3" + "outputId": "0cd15e59-82f3-4679-f9b7-f96308480e81" }, "outputs": [ { - "output_type": "error", - "ename": "RuntimeError", - "evalue": "Expected query, key, and value to have the same device type, but got query.device: cpu key.device: cpu and value.device: cuda:0 instead.", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m attn_output = torch.nn.functional.scaled_dot_product_attention(\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mquery_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mkey_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mvalue_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mattn_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mRuntimeError\u001b[0m: Expected query, key, and value to have the same device type, but got query.device: cpu key.device: cpu and value.device: cuda:0 instead." 
+ "output_type": "stream", + "name": "stdout", + "text": [ + "attn_output.shape: torch.Size([1, 16, 10, 128])\n" ] } ], @@ -767,57 +821,59 @@ " dropout_p= 0.0,\n", " # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.\n", " is_causal= True,\n", - ")" + ")\n", + "print(f\"attn_output.shape: {attn_output.shape}\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 65, "id": "c69e00a9-2373-4937-b2d3-ed3eae4481b3", "metadata": { - "id": "c69e00a9-2373-4937-b2d3-ed3eae4481b3" + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "c69e00a9-2373-4937-b2d3-ed3eae4481b3", + "outputId": "a1c4bfbe-6da9-4530-b63a-afa03322c126" }, - "outputs": [], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "attn_output.shape: torch.Size([1, 10, 2048])\n" + ] + } + ], "source": [ "attn_output = attn_output.transpose(1, 2).contiguous()\n", - "attn_output = attn_output.view(1, q_len, dim)" + "attn_output = attn_output.view(1, q_len, dim)\n", + "print(f\"attn_output.shape: {attn_output.shape}\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 66, "id": "0d3dd439-a258-48a6-9e1d-13829c94a078", "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "0d3dd439-a258-48a6-9e1d-13829c94a078", - "outputId": "20b052da-792e-4b78-ed48-8a1547834262" + "outputId": "bd35840d-21d4-45f7-9de9-76f8b5014a8b" }, "outputs": [ { - "data": { - "text/plain": [ - "tensor([[[-1.8568e-01, 1.3149e-01, -1.5167e-01, ..., 2.8487e-02,\n", - " -8.3742e-02, -2.3384e-02],\n", - " [-1.2537e-01, 2.0195e-01, -1.2300e-02, ..., -3.6986e-02,\n", - " -1.8594e-01, 9.9794e-02],\n", - " [-1.4426e-01, 1.5807e-01, -1.7747e-01, ..., -7.1516e-02,\n", - " 7.0311e-02, -1.7331e-01],\n", - " ...,\n", - " [-5.9189e-02, 4.0363e-02, -1.3974e-05, ..., -5.2831e-02,\n", - " -2.0385e-02, 8.6324e-03],\n", - " [ 3.7043e-02, 5.2902e-02, 3.0693e-03, ..., 
-8.9145e-02,\n", - " -1.0277e-01, 1.0480e-02],\n", - " [-8.8573e-02, 1.8764e-02, -4.4170e-02, ..., 1.4842e-01,\n", - " -9.0892e-02, 5.9852e-02]]])" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" + "output_type": "stream", + "name": "stdout", + "text": [ + "output_states.shape: torch.Size([1, 10, 2048])\n" + ] } ], "source": [ "output_states = torch.matmul(attn_output, o_layer0.T)\n", - "output_states" + "print(f\"output_states.shape: {output_states.shape}\")" ] }, { @@ -836,14 +892,27 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 67, "id": "76a4648d-938a-435c-9ca0-918decdc988b", "metadata": { - "id": "76a4648d-938a-435c-9ca0-918decdc988b" + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "76a4648d-938a-435c-9ca0-918decdc988b", + "outputId": "e8b6ca24-abbf-4a05-b021-aa71c9e4a2e2" }, - "outputs": [], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "output_states.shape: torch.Size([1, 10, 2048])\n" + ] + } + ], "source": [ - "output_states = output_states+token_embeddings_unnormalized" + "output_states = output_states+token_embeddings_unnormalized\n", + "print(f\"output_states.shape: {output_states.shape}\")" ] }, { @@ -861,14 +930,27 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 68, "id": "a9b06228-a58b-4dde-9f30-78e307c1f9ef", "metadata": { - "id": "a9b06228-a58b-4dde-9f30-78e307c1f9ef" + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "a9b06228-a58b-4dde-9f30-78e307c1f9ef", + "outputId": "117456af-d249-4df8-8902-5815168a66ad" }, - "outputs": [], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "second_normalized.shape: torch.Size([10, 2048])\n" + ] + } + ], "source": [ - "second_normalized = rms_norm(token_embeddings_unnormalized, model[\"model.layers.0.post_attention_layernorm.weight\"])" + "second_normalized = rms_norm(token_embeddings_unnormalized, 
model[\"model.layers.0.post_attention_layernorm.weight\"])\n", + "print(f\"second_normalized.shape: {second_normalized.shape}\")" ] }, { @@ -886,17 +968,30 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 69, "id": "f999c5c9-b451-4c32-aae7-c5b7bc18e761", "metadata": { - "id": "f999c5c9-b451-4c32-aae7-c5b7bc18e761" + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "f999c5c9-b451-4c32-aae7-c5b7bc18e761", + "outputId": "131744a9-4b0c-4f96-8e0c-2c4f0b508f16" }, - "outputs": [], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "output_after_feedforward.shape: torch.Size([10, 2048])\n" + ] + } + ], "source": [ "w1 = model[f\"model.layers.0.mlp.gate_proj.weight\"]\n", "w2 = model[f\"model.layers.0.mlp.down_proj.weight\"]\n", "w3 = model[f\"model.layers.0.mlp.up_proj.weight\"]\n", - "output_after_feedforward = torch.matmul(torch.functional.F.silu(torch.matmul(second_normalized, w1.T)) * torch.matmul(second_normalized, w3.T), w2.T)" + "output_after_feedforward = torch.matmul(torch.functional.F.silu(torch.matmul(second_normalized, w1.T)) * torch.matmul(second_normalized, w3.T), w2.T)\n", + "print(f\"output_after_feedforward.shape: {output_after_feedforward.shape}\")" ] }, { @@ -912,7 +1007,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 71, "id": "a131a7af-666f-4c45-a5d5-3e093384a57b", "metadata": { "id": "a131a7af-666f-4c45-a5d5-3e093384a57b" @@ -946,7 +1041,9 @@ "\n", " cos, sin = rotary_emb(value_states, seq_len=q_len)\n", " position_ids = torch.arange(q_len).view(1,q_len)\n", - " query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n", + " query_states, key_states = apply_rotary_pos_emb(query_states.cpu(), key_states.cpu(), cos, sin, position_ids)\n", + " query_states = query_states.to(device)\n", + " key_states = key_states.to(device)\n", "\n", " key_states = repeat_kv(key_states, n_heads // n_kv_heads)\n", " value_states 
= repeat_kv(value_states, n_heads // n_kv_heads)\n", @@ -996,27 +1093,27 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 72, "id": "e319e392-14f4-4875-bb7e-6ab1ff3f9e46", "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "e319e392-14f4-4875-bb7e-6ab1ff3f9e46", - "outputId": "8d1bf0ea-adf5-4939-d35a-858deed79044" + "outputId": "c74f4316-8725-4753-b897-26759df0daca" }, "outputs": [ { - "data": { - "text/plain": [ - "torch.Size([1, 10, 3584])" - ] - }, - "execution_count": 51, - "metadata": {}, - "output_type": "execute_result" + "output_type": "stream", + "name": "stdout", + "text": [ + "final_normalized.shape: torch.Size([1, 10, 2048])\n" + ] } ], "source": [ "final_normalized = rms_norm(final_embedding, model[\"model.norm.weight\"])\n", - "final_normalized.shape" + "print(f\"final_normalized.shape: {final_normalized.shape}\")" ] }, { @@ -1034,52 +1131,52 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 73, "id": "a59710bf-2aa5-4018-8e45-393c1af9923b", "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "a59710bf-2aa5-4018-8e45-393c1af9923b", - "outputId": "309bc665-e547-4589-d0e9-d7057a210093" + "outputId": "c8e57fb4-5060-4538-a779-0db809dda4a8" }, "outputs": [ { - "data": { - "text/plain": [ - "torch.Size([152064])" - ] - }, - "execution_count": 52, - "metadata": {}, - "output_type": "execute_result" + "output_type": "stream", + "name": "stdout", + "text": [ + "logits.shape: torch.Size([151936])\n" + ] } ], "source": [ "logits = torch.matmul(final_normalized[0][-1], model[\"lm_head.weight\"].T)\n", - "logits.shape" + "print(f\"logits.shape: {logits.shape}\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 74, "id": "fa370468-0279-4a77-ab89-259e082b068e", "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "fa370468-0279-4a77-ab89-259e082b068e", - "outputId": "ccb50178-2d6a-452f-846b-ca89ac8d3ae8" + 
"outputId": "38bc4c8f-1d05-4efb-82e9-68b3957714f4" }, "outputs": [ { - "data": { - "text/plain": [ - "tensor([55806])" - ] - }, - "execution_count": 53, - "metadata": {}, - "output_type": "execute_result" + "output_type": "stream", + "name": "stdout", + "text": [ + "next_token: tensor([55806], device='cuda:0')\n" + ] } ], "source": [ "next_token = torch.argmax(logits, dim=-1).view(1)\n", - "next_token" + "print(f\"next_token: {next_token}\")" ] }, { @@ -1097,22 +1194,29 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 75, "id": "b543329b-3a85-4bf8-850e-2616583c6cdb", "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 36 + }, "id": "b543329b-3a85-4bf8-850e-2616583c6cdb", - "outputId": "5363cdaf-a5a0-4a0e-ee66-1b6086f06d2e" + "outputId": "ea580b21-08cd-4c9b-a791-bb4ce024ab48" }, "outputs": [ { + "output_type": "execute_result", "data": { "text/plain": [ "'退'" - ] + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + } }, - "execution_count": 54, "metadata": {}, - "output_type": "execute_result" + "execution_count": 75 } ], "source": [