Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
多显卡模型加载,使用utf-8进行jsonl编码
  • Loading branch information
AreChen committed Jul 28, 2023
commit b3bb392bcbf3e18ca3f4db080a015d6d79b93d58
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,24 @@ python ./demo/run_demo.py
```python
model = AutoModel.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True).half().cuda()
```
* 如果需要使用多显卡加载模型,可以将以下代码:
```python
tokenizer = AutoTokenizer.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True, device='cuda')
model = model.eval()
```
替换为
```python
def get_model():
tokenizer = AutoTokenizer.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True)
from utils import load_model_on_gpus
model = load_model_on_gpus("THUDM/codegeex2-6b", num_gpus=2)
model = model.eval()
return tokenizer, model

tokenizer, model = get_model()
```


## 代码能力评测

Expand Down
15 changes: 11 additions & 4 deletions demo/run_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,19 @@

from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True).to('cuda:0')
model = model.eval()
def get_model():
tokenizer = AutoTokenizer.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/codegeex2-6b", trust_remote_code=True).to('cuda:0')
# 如需实现多显卡模型加载,请将上面一行注释并启用一下两行,"num_gpus"调整为自己需求的显卡数量
# from utils import load_model_on_gpus
# model = load_model_on_gpus("THUDM/codegeex2-6b", num_gpus=2)
model = model.eval()
return tokenizer, model

tokenizer, model = get_model()

examples = []
with open(os.path.join(os.path.split(os.path.realpath(__file__))[0], "example_inputs.jsonl"), "r") as f:
with open(os.path.join(os.path.split(os.path.realpath(__file__))[0], "example_inputs.jsonl"), "r", encoding="utf-8") as f:
for line in f:
examples.append(list(json.loads(line).values()))

Expand Down