2 changes: 2 additions & 0 deletions docs/exllama_v2.md
@@ -41,6 +41,8 @@ python3 -m fastchat.serve.model_worker \
--exllama-gpu-split 18,24
```

`--exllama-cache-8bit` can be used to enable ExLlamaV2's 8-bit cache and save some VRAM.

## Performance

Reference: https://github.com/turboderp/exllamav2#performance
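As a usage sketch for the new flag (not text from the PR): the model path below is a placeholder, and the other flags are assumed to match the existing worker example earlier in this doc.

```bash
python3 -m fastchat.serve.model_worker \
    --model-path models/vicuna-7B-1.1-GPTQ-4bit-128g \
    --enable-exllama \
    --exllama-max-seq-len 2048 \
    --exllama-gpu-split 18,24 \
    --exllama-cache-8bit
```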
5 changes: 5 additions & 0 deletions fastchat/model/model_adapter.py
@@ -501,6 +501,11 @@ def add_model_args(parser):
default=None,
help="Used for exllamabv2. Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7",
)
parser.add_argument(
"--exllama-cache-8bit",
action="store_true",
help="Used for exllamabv2. Use 8-bit cache to save VRAM.",
)
parser.add_argument(
"--enable-xft",
action="store_true",
6 changes: 5 additions & 1 deletion fastchat/modules/exllama.py
@@ -6,6 +6,7 @@
class ExllamaConfig:
max_seq_len: int
gpu_split: str = None
cache_8bit: bool = False


class ExllamaModel:
@@ -22,6 +23,7 @@ def load_exllama_model(model_path, exllama_config: ExllamaConfig):
ExLlamaV2Tokenizer,
ExLlamaV2,
ExLlamaV2Cache,
ExLlamaV2Cache_8bit,
)
except ImportError as e:
print(f"Error: Failed to load Exllamav2. {e}")
@@ -31,6 +33,7 @@ def load_exllama_model(model_path, exllama_config: ExllamaConfig):
exllamav2_config.model_dir = model_path
exllamav2_config.prepare()
exllamav2_config.max_seq_len = exllama_config.max_seq_len
exllamav2_config.cache_8bit = exllama_config.cache_8bit

exllama_model = ExLlamaV2(exllamav2_config)
tokenizer = ExLlamaV2Tokenizer(exllamav2_config)
@@ -40,7 +43,8 @@ def load_exllama_model(model_path, exllama_config: ExllamaConfig):
split = [float(alloc) for alloc in exllama_config.gpu_split.split(",")]
exllama_model.load(split)

exllama_cache = ExLlamaV2Cache(exllama_model)
cache_class = ExLlamaV2Cache_8bit if exllamav2_config.cache_8bit else ExLlamaV2Cache
exllama_cache = cache_class(exllama_model)
model = ExllamaModel(exllama_model=exllama_model, exllama_cache=exllama_cache)

return model, tokenizer
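For context, a minimal sketch of how callers such as `fastchat/serve/cli.py` end up exercising the new field (the values and model path below are illustrative, not part of this PR):

```python
from fastchat.modules.exllama import ExllamaConfig, load_exllama_model

# Illustrative values; the path is a placeholder for a local EXL2-quantized model directory.
config = ExllamaConfig(
    max_seq_len=4096,    # forwarded to ExLlamaV2Config.max_seq_len
    gpu_split="18,24",   # optional comma-separated per-GPU VRAM allocation (in GB)
    cache_8bit=True,     # selects ExLlamaV2Cache_8bit instead of ExLlamaV2Cache
)
model, tokenizer = load_exllama_model("models/llama-2-13b-exl2", config)
```

Since the cache object is constructed inside `load_exllama_model`, toggling 8-bit caching requires reloading the model.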
1 change: 1 addition & 0 deletions fastchat/serve/cli.py
@@ -201,6 +201,7 @@ def main(args):
exllama_config = ExllamaConfig(
max_seq_len=args.exllama_max_seq_len,
gpu_split=args.exllama_gpu_split,
cache_8bit=args.exllama_cache_8bit,
)
else:
exllama_config = None
1 change: 1 addition & 0 deletions fastchat/serve/model_worker.py
@@ -323,6 +323,7 @@ def create_model_worker():
exllama_config = ExllamaConfig(
max_seq_len=args.exllama_max_seq_len,
gpu_split=args.exllama_gpu_split,
cache_8bit=args.exllama_cache_8bit,
)
else:
exllama_config = None
1 change: 1 addition & 0 deletions fastchat/serve/multi_model_worker.py
@@ -217,6 +217,7 @@ def create_multi_model_worker():
exllama_config = ExllamaConfig(
max_seq_len=args.exllama_max_seq_len,
gpu_split=args.exllama_gpu_split,
cache_8bit=args.exllama_cache_8bit,
)
else:
exllama_config = None