From c05bfdb10fbdcb8cb513007b8fb0d3fa05777059 Mon Sep 17 00:00:00 2001 From: GoHomeToMacDonal Date: Thu, 5 Oct 2023 06:11:26 +0800 Subject: [PATCH 1/6] ChatGLM2 Support --- vllm/config.py | 4 + vllm/model_executor/__init__.py | 2 +- vllm/model_executor/model_loader.py | 1 + vllm/model_executor/models/__init__.py | 2 + vllm/model_executor/models/chatglm.py | 403 ++++++++++++++++++++ vllm/transformers_utils/config.py | 1 + vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/chatglm.py | 66 ++++ 8 files changed, 480 insertions(+), 1 deletion(-) create mode 100644 vllm/model_executor/models/chatglm.py create mode 100644 vllm/transformers_utils/configs/chatglm.py diff --git a/vllm/config.py b/vllm/config.py index 7a9417985952..e28ad3be8e61 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -166,6 +166,10 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: if getattr(self.hf_config, "num_key_value_heads", None) is not None: return (self.hf_config.num_key_value_heads // parallel_config.tensor_parallel_size) + # For ChatGLM-2: + if getattr(self.hf_config, "multi_query_group_num", None) is not None: + return (self.hf_config.multi_query_group_num // + parallel_config.tensor_parallel_size) total_num_attention_heads = self.hf_config.num_attention_heads return total_num_attention_heads // parallel_config.tensor_parallel_size diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py index 36fc30f9c1e3..e65b151168d7 100644 --- a/vllm/model_executor/__init__.py +++ b/vllm/model_executor/__init__.py @@ -6,4 +6,4 @@ "InputMetadata", "get_model", "set_random_seed", -] +] \ No newline at end of file diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 951ba1f0ceba..b1ee1f257d93 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -30,6 +30,7 @@ "OPTForCausalLM": OPTForCausalLM, "QWenLMHeadModel": QWenLMHeadModel, "RWForCausalLM": FalconForCausalLM, + "ChatGLMModel": ChatGLMForCausalLM } # FIXME(woosuk): Remove this once all models support quantization. diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 01d85355b297..e8a7106a349f 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -13,6 +13,7 @@ from vllm.model_executor.models.opt import OPTForCausalLM from vllm.model_executor.models.qwen import QWenLMHeadModel from vllm.model_executor.models.mistral import MistralForCausalLM +from vllm.model_executor.models.chatglm import ChatGLMForCausalLM __all__ = [ "AquilaForCausalLM", @@ -30,4 +31,5 @@ "OPTForCausalLM", "QWenLMHeadModel", "MistralForCausalLM", + "ChatGLMForCausalLM" ] diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py new file mode 100644 index 000000000000..f19804095e86 --- /dev/null +++ b/vllm/model_executor/models/chatglm.py @@ -0,0 +1,403 @@ +# coding=utf-8 +# Adapted from +# https://github.com/THUDM/ChatGLM2-6B +"""Inference-only ChatGLM model compatible with THUDM weights. + +The input of the model is flattened to a 1D tensor of tokens. The model uses +InputMetadata to extract the original 2D shape of the input. 
+""" +from typing import Dict, List, Optional, Tuple + +import torch +from torch import nn +from torch.nn import LayerNorm + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.weight_utils import ( + hf_model_weights_iterator, + load_tensor_parallel_weights, +) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding +from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear, + RowParallelLinear) +from vllm.sequence import SequenceOutputs + +from vllm.transformers_utils.configs import ChatGLMConfig + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class GLMAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.multi_query_attention = config.multi_query_attention + self.total_num_kv_heads = config.multi_query_group_num if config.multi_query_attention else config.num_attention_heads + assert self.total_num_kv_heads % tp_size == 0 + self.num_kv_heads = self.total_num_kv_heads // tp_size + self.head_dim = config.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + + self.query_key_value = ColumnParallelLinear( + config.hidden_size, + (self.total_num_heads + 2 * self.total_num_kv_heads) * + self.head_dim, + bias=config.add_qkv_bias, + gather_output=False, + ) + self.dense = RowParallelLinear( + self.total_num_heads * self.head_dim, + config.hidden_size, + bias=config.add_bias_linear, + input_is_parallel=True, + ) + + self.attn = PagedAttentionWithRoPE( + self.num_heads, + self.head_dim, + self.scaling, + rotary_dim=self.head_dim // 2, + num_kv_heads=self.num_kv_heads, + is_neox_style=False, + # is_glm_style=True + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + qkv, _ = self.query_key_value(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + key_cache, value_cache = kv_cache + + context_layer = self.attn( + position_ids, + q, + k, + v, + key_cache, + value_cache, + input_metadata, + cache_event, + ) + + + attn_output, _ = self.dense(context_layer) + + return attn_output + + +class GLMMLP(nn.Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__(self, config): + super(GLMMLP, self).__init__() + + self.add_bias = config.add_bias_linear + + # Project to 4h. 
If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf + self.dense_h_to_4h = ColumnParallelLinear( + config.hidden_size, + config.ffn_hidden_size * 2, + bias=config.add_bias_linear, + gather_output=False, + ) + + self.activation_func = SiluAndMul() + + # Project back to h. + self.dense_4h_to_h = RowParallelLinear( + config.ffn_hidden_size, + config.hidden_size, + bias=config.add_bias_linear, + input_is_parallel=True, + ) + + def forward(self, hidden_states): + # [s, b, 4hp] + intermediate_parallel, _ = self.dense_h_to_4h(hidden_states) + intermediate_parallel = self.activation_func(intermediate_parallel) + # [s, b, h] + output, _ = self.dense_4h_to_h(intermediate_parallel) + return output + + +class GLMBlock(nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__( + self, + config, + ): + super(GLMBlock, self).__init__() + self.apply_residual_connection_post_layernorm = ( + config.apply_residual_connection_post_layernorm + ) + + self.fp32_residual_connection = config.fp32_residual_connection + + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Layernorm on the input data. + self.input_layernorm = LayerNormFunc( + config.hidden_size, eps=config.layernorm_epsilon + ) + + # Self attention. + self.self_attention = GLMAttention(config) + self.hidden_dropout = config.hidden_dropout + + # Layernorm on the attention output + self.post_attention_layernorm = LayerNormFunc( + config.hidden_size, eps=config.layernorm_epsilon + ) + + # MLP + self.mlp = GLMMLP(config) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + # hidden_states: [num_tokens, h] + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output = self.self_attention( + hidden_states=layernorm_output, + position_ids=position_ids, + kv_cache=kv_cache, + input_metadata=input_metadata, + cache_event=cache_event, + ) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + layernorm_input = residual + attention_output + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + output = self.mlp(layernorm_output) + residual + + return output + + +class GLMTransformer(nn.Module): + """Transformer class.""" + + def __init__(self, config): + super(GLMTransformer, self).__init__() + self.post_layer_norm = config.post_layer_norm + + # Number of layers. + self.num_layers = config.num_layers + + # Transformer layers. + self.layers = nn.ModuleList([GLMBlock(config) for i in range(self.num_layers)]) + + if self.post_layer_norm: + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Final layer norm before output. 
+ self.final_layernorm = LayerNormFunc( + config.hidden_size, eps=config.layernorm_epsilon + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> torch.Tensor: + for i in range(self.num_layers): + if cache_events is None: + cache_event = None + else: + cache_event = cache_events[i] + layer = self.layers[i] + hidden_states = layer( + hidden_states=hidden_states, + position_ids=position_ids, + kv_cache=kv_caches[i], + input_metadata=input_metadata, + cache_event=cache_event, + ) + # Final layer norm. + if self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states + + +class ChatGLMModel(nn.Module): + def __init__(self, config): + super().__init__() + + self.embedding = VocabParallelEmbedding( + config.padded_vocab_size, config.hidden_size + ) + + self.num_layers = config.num_layers + self.multi_query_group_num = config.multi_query_group_num + self.kv_channels = config.kv_channels + self.encoder = GLMTransformer(config) + + self.output_layer = ColumnParallelLinear( + config.hidden_size, + config.padded_vocab_size, + bias=False, + gather_output=False, + params_dtype=config.torch_dtype, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ): + inputs_embeds = self.embedding(input_ids) + + # Run encoder. + hidden_states = self.encoder( + hidden_states=inputs_embeds, + position_ids=position_ids, + kv_caches=kv_caches, + input_metadata=input_metadata, + cache_events=cache_events, + ) + + return hidden_states + + +class ChatGLMForCausalLM(nn.Module): + def __init__(self, config: ChatGLMConfig): + super().__init__() + self.config: ChatGLMConfig = config + self.transformer = ChatGLMModel(config) + self.lm_head_weight = self.transformer.output_layer.weight + self.sampler = Sampler(config.padded_vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> Dict[int, SequenceOutputs]: + hidden_states = self.transformer( + input_ids, positions, kv_caches, input_metadata, cache_events + ) + next_tokens = self.sampler(self.lm_head_weight, hidden_states, input_metadata) + return next_tokens + + _column_parallel_weights = [ + "output_layer.weight", + "embedding.weight", + ] + _row_parallel_weights = [ + "dense_4h_to_h", + "self_attention.dense" + ] + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + q_proj_shard_size = (self.config.hidden_size // tp_size) + kv_proj_shard_size = (self.config.hidden_size // + self.config.num_attention_heads * + self.config.multi_query_group_num // tp_size) + + mlp_hidden_shard_size = self.config.ffn_hidden_size // tp_size + + state_dict = self.state_dict() + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision + ): + if "word_embeddings" in name: + name = name.replace(".word_embeddings", "") + + if name in state_dict: + param = state_dict[name] + if "query_key_value" in name: + q_offset = q_proj_shard_size * tp_rank + k_offset = 
q_proj_shard_size * tp_size + kv_proj_shard_size * tp_rank + v_offset = q_proj_shard_size * tp_size + kv_proj_shard_size * (tp_size + tp_rank) + wq = loaded_weight[q_offset:q_offset + q_proj_shard_size] + wk = loaded_weight[k_offset:k_offset + kv_proj_shard_size] + wv = loaded_weight[v_offset:v_offset + kv_proj_shard_size] + loaded_weight = torch.cat([wq, wk, wv], dim=0) + param.data.copy_(loaded_weight) + continue + + if "dense_h_to_4h" in name: + w_gate = loaded_weight[mlp_hidden_shard_size * tp_rank:mlp_hidden_shard_size * (tp_rank + 1)] + w_proj = loaded_weight[mlp_hidden_shard_size * (tp_size + tp_rank):mlp_hidden_shard_size * (tp_size + tp_rank + 1)] + loaded_weight = torch.cat([w_gate, w_proj], dim=0) + param.data.copy_(loaded_weight) + continue + + load_tensor_parallel_weights( + param, + loaded_weight, + name, + self._column_parallel_weights, + self._row_parallel_weights, + tp_rank, + ) + elif name == 'transformer.rotary_pos_emb.inv_freq': + continue + else: + print("Warning never found tensor's name:", name) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index a1efbedb6895..84dc0df46856 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -11,6 +11,7 @@ "qwen": QWenConfig, "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) + "chatglm": ChatGLMConfig, } diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 3955c772b7b3..7b72d9527d32 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -7,6 +7,7 @@ # `FalconConfig` class from the official HuggingFace transformers library. from vllm.transformers_utils.configs.falcon import RWConfig from vllm.transformers_utils.configs.mistral import MistralConfig +from vllm.transformers_utils.configs.chatglm import ChatGLMConfig __all__ = [ "MPTConfig", @@ -15,4 +16,5 @@ "QWenConfig", "RWConfig", "MistralConfig", + "ChatGLMConfig" ] diff --git a/vllm/transformers_utils/configs/chatglm.py b/vllm/transformers_utils/configs/chatglm.py new file mode 100644 index 000000000000..76faf5aa2292 --- /dev/null +++ b/vllm/transformers_utils/configs/chatglm.py @@ -0,0 +1,66 @@ +# coding=utf-8 +# Adapted from +# https://github.com/THUDM/ChatGLM2-6B +from transformers import PretrainedConfig + + +class ChatGLMConfig(PretrainedConfig): + model_type = "chatglm" + attribute_map = { + "num_hidden_layers": "num_layers", + "n_head_kv": "multi_query_group_num", + } + + def __init__(self, + num_layers=28, + padded_vocab_size=65024, + hidden_size=4096, + ffn_hidden_size=13696, + kv_channels=128, + num_attention_heads=32, + seq_length=2048, + hidden_dropout=0.0, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + rmsnorm=True, + apply_residual_connection_post_layernorm=False, + post_layer_norm=True, + add_bias_linear=False, + add_qkv_bias=False, + interleaved_qkv=False, + bias_dropout_fusion=True, + multi_query_attention=False, + multi_query_group_num=1, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=True, + fp32_residual_connection=False, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs): + self.num_layers = num_layers + self.vocab_size = padded_vocab_size + self.padded_vocab_size = padded_vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.kv_channels = kv_channels + self.num_attention_heads = num_attention_heads + 
self.seq_length = seq_length + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.layernorm_epsilon = layernorm_epsilon + self.rmsnorm = rmsnorm + self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm + self.post_layer_norm = post_layer_norm + self.add_bias_linear = add_bias_linear + self.add_qkv_bias = add_qkv_bias + self.bias_dropout_fusion = bias_dropout_fusion + self.multi_query_attention = multi_query_attention + self.multi_query_group_num = multi_query_group_num + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.fp32_residual_connection = fp32_residual_connection + self.quantization_bit = quantization_bit + self.pre_seq_len = pre_seq_len + self.prefix_projection = prefix_projection + super().__init__(**kwargs) From 9ac35d587e9daa614f4b77542f99efd142c6ce49 Mon Sep 17 00:00:00 2001 From: GoHomeToMacDonal <143197337+GoHomeToMacDonal@users.noreply.github.com> Date: Sat, 4 Nov 2023 12:52:00 +0800 Subject: [PATCH 2/6] Update __init__.py --- vllm/model_executor/models/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index acc689f79d17..ea83720982a6 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -13,7 +13,6 @@ from vllm.model_executor.models.mpt import MptForCausalLM from vllm.model_executor.models.opt import OPTForCausalLM from vllm.model_executor.models.qwen import QWenLMHeadModel -from vllm.model_executor.models.mistral import MistralForCausalLM from vllm.model_executor.models.chatglm import ChatGLMForCausalLM __all__ = [ From 302d35d6f672aa3ecac608d2341ba6f980f3e253 Mon Sep 17 00:00:00 2001 From: GoHomeToMacDonal <143197337+GoHomeToMacDonal@users.noreply.github.com> Date: Sat, 4 Nov 2023 13:42:39 +0800 Subject: [PATCH 3/6] Update __init__.py --- vllm/transformers_utils/configs/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 982fa7b9677b..7bd54e9afd2a 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -5,7 +5,6 @@ # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the # `FalconConfig` class from the official HuggingFace transformers library. 
from vllm.transformers_utils.configs.falcon import RWConfig -from vllm.transformers_utils.configs.mistral import MistralConfig from vllm.transformers_utils.configs.chatglm import ChatGLMConfig __all__ = [ @@ -13,6 +12,5 @@ "AquilaConfig", "QWenConfig", "RWConfig", - "MistralConfig", "ChatGLMConfig" ] From 2a8e01d85393969d590e0829b22804de113e9f49 Mon Sep 17 00:00:00 2001 From: zhinanertui Date: Sat, 4 Nov 2023 05:49:07 +0000 Subject: [PATCH 4/6] Code formatting --- vllm/model_executor/models/__init__.py | 20 ++---- vllm/model_executor/models/chatglm.py | 77 +++++++++++---------- vllm/transformers_utils/configs/__init__.py | 6 +- vllm/transformers_utils/configs/chatglm.py | 3 +- 4 files changed, 49 insertions(+), 57 deletions(-) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index ea83720982a6..5d5bb7ca0ace 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -16,20 +16,10 @@ from vllm.model_executor.models.chatglm import ChatGLMForCausalLM __all__ = [ - "AquilaForCausalLM", - "BaiChuanForCausalLM", - "BaichuanForCausalLM", - "BloomForCausalLM", - "FalconForCausalLM", - "GPT2LMHeadModel", - "GPTBigCodeForCausalLM", - "GPTJForCausalLM", - "GPTNeoXForCausalLM", - "InternLMForCausalLM", - "LlamaForCausalLM", - "MptForCausalLM", - "OPTForCausalLM", - "QWenLMHeadModel", - "MistralForCausalLM", + "AquilaForCausalLM", "BaiChuanForCausalLM", "BaichuanForCausalLM", + "BloomForCausalLM", "FalconForCausalLM", "GPT2LMHeadModel", + "GPTBigCodeForCausalLM", "GPTJForCausalLM", "GPTNeoXForCausalLM", + "InternLMForCausalLM", "LlamaForCausalLM", "MptForCausalLM", + "OPTForCausalLM", "QWenLMHeadModel", "MistralForCausalLM", "ChatGLMForCausalLM" ] diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index f19804095e86..caea0f8c1ceb 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -26,8 +26,10 @@ get_tensor_model_parallel_world_size, ) from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding -from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.parallel_utils.layers import ( + ColumnParallelLinear, + RowParallelLinear, +) from vllm.sequence import SequenceOutputs from vllm.transformers_utils.configs import ChatGLMConfig @@ -36,6 +38,7 @@ class GLMAttention(nn.Module): + def __init__(self, config): super().__init__() self.hidden_size = config.hidden_size @@ -44,7 +47,9 @@ def __init__(self, config): assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size self.multi_query_attention = config.multi_query_attention - self.total_num_kv_heads = config.multi_query_group_num if config.multi_query_attention else config.num_attention_heads + self.total_num_kv_heads = (config.multi_query_group_num + if config.multi_query_attention else + config.num_attention_heads) assert self.total_num_kv_heads % tp_size == 0 self.num_kv_heads = self.total_num_kv_heads // tp_size self.head_dim = config.hidden_size // self.total_num_heads @@ -99,7 +104,6 @@ def forward( cache_event, ) - attn_output, _ = self.dense(context_layer) return attn_output @@ -118,7 +122,7 @@ def __init__(self, config): self.add_bias = config.add_bias_linear - # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf + # Project to 4h. 
self.dense_h_to_4h = ColumnParallelLinear( config.hidden_size, config.ffn_hidden_size * 2, @@ -158,16 +162,14 @@ def __init__( ): super(GLMBlock, self).__init__() self.apply_residual_connection_post_layernorm = ( - config.apply_residual_connection_post_layernorm - ) + config.apply_residual_connection_post_layernorm) self.fp32_residual_connection = config.fp32_residual_connection LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm # Layernorm on the input data. - self.input_layernorm = LayerNormFunc( - config.hidden_size, eps=config.layernorm_epsilon - ) + self.input_layernorm = LayerNormFunc(config.hidden_size, + eps=config.layernorm_epsilon) # Self attention. self.self_attention = GLMAttention(config) @@ -175,8 +177,7 @@ def __init__( # Layernorm on the attention output self.post_attention_layernorm = LayerNormFunc( - config.hidden_size, eps=config.layernorm_epsilon - ) + config.hidden_size, eps=config.layernorm_epsilon) # MLP self.mlp = GLMMLP(config) @@ -234,14 +235,14 @@ def __init__(self, config): self.num_layers = config.num_layers # Transformer layers. - self.layers = nn.ModuleList([GLMBlock(config) for i in range(self.num_layers)]) + self.layers = nn.ModuleList( + [GLMBlock(config) for i in range(self.num_layers)]) if self.post_layer_norm: LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm # Final layer norm before output. - self.final_layernorm = LayerNormFunc( - config.hidden_size, eps=config.layernorm_epsilon - ) + self.final_layernorm = LayerNormFunc(config.hidden_size, + eps=config.layernorm_epsilon) def forward( self, @@ -272,12 +273,12 @@ def forward( class ChatGLMModel(nn.Module): + def __init__(self, config): super().__init__() - self.embedding = VocabParallelEmbedding( - config.padded_vocab_size, config.hidden_size - ) + self.embedding = VocabParallelEmbedding(config.padded_vocab_size, + config.hidden_size) self.num_layers = config.num_layers self.multi_query_group_num = config.multi_query_group_num @@ -315,6 +316,7 @@ def forward( class ChatGLMForCausalLM(nn.Module): + def __init__(self, config: ChatGLMConfig): super().__init__() self.config: ChatGLMConfig = config @@ -330,20 +332,17 @@ def forward( input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], ) -> Dict[int, SequenceOutputs]: - hidden_states = self.transformer( - input_ids, positions, kv_caches, input_metadata, cache_events - ) - next_tokens = self.sampler(self.lm_head_weight, hidden_states, input_metadata) + hidden_states = self.transformer(input_ids, positions, kv_caches, + input_metadata, cache_events) + next_tokens = self.sampler(self.lm_head_weight, hidden_states, + input_metadata) return next_tokens _column_parallel_weights = [ "output_layer.weight", "embedding.weight", ] - _row_parallel_weights = [ - "dense_4h_to_h", - "self_attention.dense" - ] + _row_parallel_weights = ["dense_4h_to_h", "self_attention.dense"] def load_weights( self, @@ -355,7 +354,7 @@ def load_weights( tp_rank = get_tensor_model_parallel_rank() tp_size = get_tensor_model_parallel_world_size() - q_proj_shard_size = (self.config.hidden_size // tp_size) + q_proj_shard_size = self.config.hidden_size // tp_size kv_proj_shard_size = (self.config.hidden_size // self.config.num_attention_heads * self.config.multi_query_group_num // tp_size) @@ -364,17 +363,18 @@ def load_weights( state_dict = self.state_dict() for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision - ): + model_name_or_path, cache_dir, load_format, revision): if "word_embeddings" in name: 
name = name.replace(".word_embeddings", "") - + if name in state_dict: param = state_dict[name] if "query_key_value" in name: q_offset = q_proj_shard_size * tp_rank - k_offset = q_proj_shard_size * tp_size + kv_proj_shard_size * tp_rank - v_offset = q_proj_shard_size * tp_size + kv_proj_shard_size * (tp_size + tp_rank) + k_offset = (q_proj_shard_size * tp_size + + kv_proj_shard_size * tp_rank) + v_offset = q_proj_shard_size * tp_size + \ + kv_proj_shard_size * (tp_size + tp_rank) wq = loaded_weight[q_offset:q_offset + q_proj_shard_size] wk = loaded_weight[k_offset:k_offset + kv_proj_shard_size] wv = loaded_weight[v_offset:v_offset + kv_proj_shard_size] @@ -383,8 +383,13 @@ def load_weights( continue if "dense_h_to_4h" in name: - w_gate = loaded_weight[mlp_hidden_shard_size * tp_rank:mlp_hidden_shard_size * (tp_rank + 1)] - w_proj = loaded_weight[mlp_hidden_shard_size * (tp_size + tp_rank):mlp_hidden_shard_size * (tp_size + tp_rank + 1)] + w_gate = loaded_weight[mlp_hidden_shard_size * + tp_rank:mlp_hidden_shard_size * + (tp_rank + 1)] + w_proj = loaded_weight[mlp_hidden_shard_size * + (tp_size + + tp_rank):mlp_hidden_shard_size * + (tp_size + tp_rank + 1)] loaded_weight = torch.cat([w_gate, w_proj], dim=0) param.data.copy_(loaded_weight) continue @@ -397,7 +402,7 @@ def load_weights( self._row_parallel_weights, tp_rank, ) - elif name == 'transformer.rotary_pos_emb.inv_freq': + elif name == "transformer.rotary_pos_emb.inv_freq": continue else: print("Warning never found tensor's name:", name) diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 7bd54e9afd2a..4de8292f48b0 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -8,9 +8,5 @@ from vllm.transformers_utils.configs.chatglm import ChatGLMConfig __all__ = [ - "BaiChuanConfig", - "AquilaConfig", - "QWenConfig", - "RWConfig", - "ChatGLMConfig" + "BaiChuanConfig", "AquilaConfig", "QWenConfig", "RWConfig", "ChatGLMConfig" ] diff --git a/vllm/transformers_utils/configs/chatglm.py b/vllm/transformers_utils/configs/chatglm.py index 76faf5aa2292..63d6f0f5b9ce 100644 --- a/vllm/transformers_utils/configs/chatglm.py +++ b/vllm/transformers_utils/configs/chatglm.py @@ -50,7 +50,8 @@ def __init__(self, self.attention_dropout = attention_dropout self.layernorm_epsilon = layernorm_epsilon self.rmsnorm = rmsnorm - self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm + self.apply_residual_connection_post_layernorm = ( + apply_residual_connection_post_layernorm) self.post_layer_norm = post_layer_norm self.add_bias_linear = add_bias_linear self.add_qkv_bias = add_qkv_bias From 251f70cd5b6a7687d109239b882fa7ccd1d348d4 Mon Sep 17 00:00:00 2001 From: zhinanertui Date: Sat, 4 Nov 2023 09:41:00 +0000 Subject: [PATCH 5/6] code formatting --- vllm/model_executor/__init__.py | 2 +- vllm/model_executor/models/chatglm.py | 20 ++++++++++---------- vllm/transformers_utils/configs/chatglm.py | 1 + 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py index e65b151168d7..36fc30f9c1e3 100644 --- a/vllm/model_executor/__init__.py +++ b/vllm/model_executor/__init__.py @@ -6,4 +6,4 @@ "InputMetadata", "get_model", "set_random_seed", -] \ No newline at end of file +] diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index caea0f8c1ceb..39770685b442 100644 --- a/vllm/model_executor/models/chatglm.py +++ 
b/vllm/model_executor/models/chatglm.py @@ -118,7 +118,7 @@ class GLMMLP(nn.Module): """ def __init__(self, config): - super(GLMMLP, self).__init__() + super().__init__() self.add_bias = config.add_bias_linear @@ -160,23 +160,23 @@ def __init__( self, config, ): - super(GLMBlock, self).__init__() + super().__init__() self.apply_residual_connection_post_layernorm = ( config.apply_residual_connection_post_layernorm) self.fp32_residual_connection = config.fp32_residual_connection - LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm # Layernorm on the input data. - self.input_layernorm = LayerNormFunc(config.hidden_size, - eps=config.layernorm_epsilon) + self.input_layernorm = layer_norm_func(config.hidden_size, + eps=config.layernorm_epsilon) # Self attention. self.self_attention = GLMAttention(config) self.hidden_dropout = config.hidden_dropout # Layernorm on the attention output - self.post_attention_layernorm = LayerNormFunc( + self.post_attention_layernorm = layer_norm_func( config.hidden_size, eps=config.layernorm_epsilon) # MLP @@ -228,7 +228,7 @@ class GLMTransformer(nn.Module): """Transformer class.""" def __init__(self, config): - super(GLMTransformer, self).__init__() + super().__init__() self.post_layer_norm = config.post_layer_norm # Number of layers. @@ -239,10 +239,10 @@ def __init__(self, config): [GLMBlock(config) for i in range(self.num_layers)]) if self.post_layer_norm: - LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm # Final layer norm before output. - self.final_layernorm = LayerNormFunc(config.hidden_size, - eps=config.layernorm_epsilon) + self.final_layernorm = layer_norm_func( + config.hidden_size, eps=config.layernorm_epsilon) def forward( self, diff --git a/vllm/transformers_utils/configs/chatglm.py b/vllm/transformers_utils/configs/chatglm.py index 63d6f0f5b9ce..c4244f8c77f4 100644 --- a/vllm/transformers_utils/configs/chatglm.py +++ b/vllm/transformers_utils/configs/chatglm.py @@ -64,4 +64,5 @@ def __init__(self, self.quantization_bit = quantization_bit self.pre_seq_len = pre_seq_len self.prefix_projection = prefix_projection + self.interleaved_qkv = interleaved_qkv super().__init__(**kwargs) From be19ad5af151a3a1ab8d9165cd6e8d60399f7c4d Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Tue, 7 Nov 2023 00:07:35 +0000 Subject: [PATCH 6/6] fix style --- vllm/model_executor/models/chatglm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 39770685b442..8acc8e468b65 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -373,8 +373,8 @@ def load_weights( q_offset = q_proj_shard_size * tp_rank k_offset = (q_proj_shard_size * tp_size + kv_proj_shard_size * tp_rank) - v_offset = q_proj_shard_size * tp_size + \ - kv_proj_shard_size * (tp_size + tp_rank) + v_offset = (q_proj_shard_size * tp_size + + kv_proj_shard_size * (tp_size + tp_rank)) wq = loaded_weight[q_offset:q_offset + q_proj_shard_size] wk = loaded_weight[k_offset:k_offset + kv_proj_shard_size] wv = loaded_weight[v_offset:v_offset + kv_proj_shard_size]
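
Note (editorial, not part of the patch series): the least obvious piece of this PR is how ChatGLMForCausalLM.load_weights shards the checkpoint's packed query_key_value and dense_h_to_4h weights across tensor-parallel ranks. The sketch below reproduces the same offset arithmetic on dummy tensors so the slices can be checked by hand. It is a minimal illustration only: the sizes are assumed ChatGLM2-style values, the tensors are fakes, and nothing here is imported from vLLM.

import torch

# Illustrative ChatGLM2-style sizes (assumed for this example, not read from a real checkpoint).
hidden_size = 4096
num_attention_heads = 32
multi_query_group_num = 2          # number of KV groups under multi-query attention
ffn_hidden_size = 13696
head_dim = hidden_size // num_attention_heads   # 128, i.e. kv_channels

tp_size, tp_rank = 2, 1            # pretend we are rank 1 of a 2-way tensor-parallel group

# The checkpoint stores QKV packed along dim 0 as [all q heads | all k heads | all v heads].
qkv_rows = (num_attention_heads + 2 * multi_query_group_num) * head_dim
full_qkv = torch.arange(qkv_rows).unsqueeze(1).expand(qkv_rows, 4)      # dummy weight

# Same shard sizes as in load_weights above.
q_proj_shard_size = hidden_size // tp_size
kv_proj_shard_size = (hidden_size // num_attention_heads *
                      multi_query_group_num // tp_size)

# Same offsets as in load_weights above: this rank's q slice, then its k and v
# slices, which start after all q rows and after all k rows respectively.
q_offset = q_proj_shard_size * tp_rank
k_offset = q_proj_shard_size * tp_size + kv_proj_shard_size * tp_rank
v_offset = (q_proj_shard_size * tp_size +
            kv_proj_shard_size * (tp_size + tp_rank))

wq = full_qkv[q_offset:q_offset + q_proj_shard_size]
wk = full_qkv[k_offset:k_offset + kv_proj_shard_size]
wv = full_qkv[v_offset:v_offset + kv_proj_shard_size]
shard = torch.cat([wq, wk, wv], dim=0)    # what this rank's query_key_value param receives
print(shard.shape)                        # torch.Size([2304, 4]) = q_shard + 2 * kv_shard rows

# dense_h_to_4h packs [gate | up] along dim 0 and is sharded the same way, so that
# SiluAndMul sees this rank's gate half followed by its up half.
gate_up = torch.arange(2 * ffn_hidden_size).unsqueeze(1).expand(-1, 4)  # dummy weight
mlp_hidden_shard_size = ffn_hidden_size // tp_size
w_gate = gate_up[mlp_hidden_shard_size * tp_rank:
                 mlp_hidden_shard_size * (tp_rank + 1)]
w_proj = gate_up[mlp_hidden_shard_size * (tp_size + tp_rank):
                 mlp_hidden_shard_size * (tp_size + tp_rank + 1)]
print(w_gate.shape, w_proj.shape)         # each torch.Size([6848, 4])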