code formatting
GoHomeToMacDonal committed Nov 4, 2023
commit 251f70cd5b6a7687d109239b882fa7ccd1d348d4
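
For context, the "code formatting" in this commit boils down to two mechanical cleanups that repeat through the hunks below: zero-argument super() calls in place of the Python 2 style super(Class, self), and a snake_case name (layer_norm_func) for the local variable that selects the layer-norm class. A minimal sketch of both patterns, using a hypothetical ExampleBlock rather than the real GLM classes, and an assumed import path for vLLM's RMSNorm (the diff does not show its import):

import torch.nn as nn
# Assumed import path for vLLM's RMSNorm; not shown in the diff below.
from vllm.model_executor.layers.layernorm import RMSNorm


class ExampleBlock(nn.Module):
    """Hypothetical module illustrating the cleanups applied by this commit."""

    def __init__(self, hidden_size: int, eps: float, use_rmsnorm: bool):
        # Python 3 form: no class/self arguments to super().
        super().__init__()
        # A local variable holding a class gets a snake_case name per PEP 8.
        layer_norm_func = RMSNorm if use_rmsnorm else nn.LayerNorm
        self.input_layernorm = layer_norm_func(hidden_size, eps=eps)
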
vllm/model_executor/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -6,4 +6,4 @@
    "InputMetadata",
    "get_model",
    "set_random_seed",
-]
+]
vllm/model_executor/models/chatglm.py (20 changes: 10 additions & 10 deletions)

@@ -118,7 +118,7 @@ class GLMMLP(nn.Module):
    """

    def __init__(self, config):
-        super(GLMMLP, self).__init__()
+        super().__init__()

        self.add_bias = config.add_bias_linear

@@ -160,23 +160,23 @@ def __init__(
        self,
        config,
    ):
-        super(GLMBlock, self).__init__()
+        super().__init__()
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

        self.fp32_residual_connection = config.fp32_residual_connection

-        LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
+        layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
        # Layernorm on the input data.
-        self.input_layernorm = LayerNormFunc(config.hidden_size,
-                                             eps=config.layernorm_epsilon)
+        self.input_layernorm = layer_norm_func(config.hidden_size,
+                                               eps=config.layernorm_epsilon)

        # Self attention.
        self.self_attention = GLMAttention(config)
        self.hidden_dropout = config.hidden_dropout

        # Layernorm on the attention output
-        self.post_attention_layernorm = LayerNormFunc(
+        self.post_attention_layernorm = layer_norm_func(
            config.hidden_size, eps=config.layernorm_epsilon)

        # MLP
@@ -228,7 +228,7 @@ class GLMTransformer(nn.Module):
    """Transformer class."""

    def __init__(self, config):
-        super(GLMTransformer, self).__init__()
+        super().__init__()
        self.post_layer_norm = config.post_layer_norm

        # Number of layers.
@@ -239,10 +239,10 @@ def __init__(self, config):
            [GLMBlock(config) for i in range(self.num_layers)])

        if self.post_layer_norm:
-            LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
+            layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
-            self.final_layernorm = LayerNormFunc(config.hidden_size,
-                                                 eps=config.layernorm_epsilon)
+            self.final_layernorm = layer_norm_func(
+                config.hidden_size, eps=config.layernorm_epsilon)

    def forward(
        self,
vllm/transformers_utils/configs/chatglm.py (1 change: 1 addition & 0 deletions)

@@ -64,4 +64,5 @@ def __init__(self,
        self.quantization_bit = quantization_bit
        self.pre_seq_len = pre_seq_len
        self.prefix_projection = prefix_projection
+        self.interleaved_qkv = interleaved_qkv
        super().__init__(**kwargs)
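
The one added line exposes interleaved_qkv as a config attribute. A minimal usage sketch, assuming the class defined in this file is ChatGLMConfig and that its __init__ (not shown in the hunk) accepts interleaved_qkv as a keyword with a default:

# Hypothetical usage; the class name and keyword default are assumptions, since
# the hunk only shows the attribute assignment.
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig

config = ChatGLMConfig(interleaved_qkv=False)
# Model code (e.g. the chatglm weight loading path) can now branch on this flag.
print(config.interleaved_qkv)
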