Skip to content

Commit e54916c

Browse files
committed
Add tutorial
1 parent bbc9685 commit e54916c

File tree

4 files changed

+172
-0
lines changed

4 files changed

+172
-0
lines changed

README_CN.md

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# chat2db-sqlcoder-deploy部署
2+
3+
语言:中文 | [English](README.md)
4+
5+
## 📖 简介
6+
这个工程介绍了如何在阿里云上免费部署sqlcoder的8bit量化模型,并将大模型应用到Chat2DB客户端中。
7+
8+
!!!请注意,sqlcoder项目主要是针对SQL生成的,所以在自然语言转SQL方面表现较好,但是在SQL解释、SQL优化和SQL转化方面表现略差,仅供大家实验参考,切勿迁怒于模型或产品。
9+
10+
## 📦 硬件要求
11+
| 模型 | 最低GPU显存(推理) | 最低GPU显存(高效参数微调) |
12+
|:-------------:|:-----------:|:---------------:|
13+
| sqlcoder-int8 | 20GB | 20GB |
14+
15+
16+
## 📦 部署
17+
### 📦 在阿里云DSW中部署8bit模型
18+
1. [阿里云免费使用平台](https://free.aliyun.com/)申请DSW免费试用。
19+
<img src="https://alidocs.oss-cn-zhangjiakou.aliyuncs.com/res/4j6OJdYA60Y7n3p8/img/e5141c12-0279-451b-9e47-5125a5a34731.png?x-oss-process=image/resize,w_1280,m_lfit,limit_1">
20+
2. 创建一个DSW实例,资源组选择可以抵扣资源包的资源组,实例镜像选择pytorch:1.12-gpu-py39-cu113-ubuntu20.04
21+
<img src="https://alidocs.oss-cn-zhangjiakou.aliyuncs.com/res/4j6OJdYA60Y7n3p8/img/d5ed7234-afb3-49de-a2a2-db6aa0424efa.png?x-oss-process=image/resize,w_1280,m_lfit,limit_1">
22+
<img src="https://alidocs.oss-cn-zhangjiakou.aliyuncs.com/res/4j6OJdYA60Y7n3p8/img/26c3961f-967d-4b11-8a81-4b037c833344.png?x-oss-process=image/resize,w_1280,m_lfit,limit_1">
23+
3. 安装本仓库中的[requirements.txt](requirements.txt)中的依赖包
24+
```bash
25+
pip install -r requirements.txt
26+
```
27+
4. 因为要跑8bit的量化模型,所以还需要下载bitsandbytes包,执行下面的命令下载最新版本,否则cuda有可能会出现不兼容的情况
28+
```bash
29+
pip install -i https://test.pypi.org/simple/ bitsandbytes
30+
```
31+
5. 在DSW实例中打开一个terminal,创建sqlcoder-model和sqlcoder文件夹
32+
6. 在sqlcoder-model文件夹下下载sqlcoder模型,执行下面的命令,请确保模型里面的几个bin文件下载完整且正确
33+
```bash
34+
git clone https://huggingface.co/defog/sqlcoder
35+
```
36+
7. 将本项目下的api.py和prompt.md文件拷贝到sqlcoder文件夹下
37+
8. 安装fastapi相关包
38+
```bash
39+
pip install fastapi nest-asyncio pyngrok uvicorn
40+
```
41+
9. 在sqlcoder文件夹下执行下面的命令,启动api服务
42+
```bash
43+
python api.py
44+
```
45+
10. 执行以上步骤之后,你将得到一个api url,类似于`https://dfb1-34-87-2-137.ngrok.io`
46+
<img src="https://alidocs.oss-cn-zhangjiakou.aliyuncs.com/res/4j6OJdYA60Y7n3p8/img/086b2121-19d3-4bff-a188-91e51d0c208d.png?x-oss-process=image/resize,w_1280,m_lfit,limit_1">
47+
48+
11. 将api url复制到chat2db客户端中,即可开始使用模型生成SQL了。参考下图进行配置
49+
<img src="https://alidocs.oss-cn-zhangjiakou.aliyuncs.com/res/4j6OJdYA60Y7n3p8/img/ca844185-2744-49e0-ab75-245e19b872d6.png?x-oss-process=image/resize,w_640,m_lfit,limit_1">
50+
51+
- 实验结果如下
52+
<img src="https://alidocs.oss-cn-zhangjiakou.aliyuncs.com/res/4j6OJdYA60Y7n3p8/img/d3f319f6-2612-4352-ab46-99ff92dace63.png?x-oss-process=image/resize,w_1280,m_lfit,limit_1">
53+
* 注意: 模型推理时间可能会比较长,会有明显的卡顿。
54+
55+
### 📦 在阿里云DSW中部署非量化模型
56+
* 如果机器资源允许,可以尝试部署非量化的sqlcoder模型,在生成SQL的准确率上会比8bit的模型高一些,但是需要更多的显存和更长的推理时间。
57+
* 部署非量化模型的步骤同上,只需要将api.py文件中的模型加载改成float16的模型即可,具体如下:
58+
```python
59+
model = AutoModelForCausalLM.from_pretrained("/mnt/workspace/sqlcoder-model/sqlcoder",
60+
trust_remote_code=True,
61+
torch_dtype=torch.float16,
62+
# load_in_8bit=True,
63+
device_map="auto",
64+
use_cache=True)
65+
```
66+
67+
### 📦 在其他云资源上部署sqlcoder模型
68+
* 本教程虽然写的是在阿里云DSW环境上完成的,但是本教程中的脚本和命令并没有进行任何定制,理论上遵循以上步骤,可以在任何云资源上进行部署。

api.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import torch
2+
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, pipeline
3+
import argparse
4+
from fastapi import FastAPI, Request
5+
import uvicorn, json, datetime
6+
import nest_asyncio
7+
from pyngrok import ngrok
8+
9+
DEVICE = "cuda"
10+
DEVICE_ID = "0"
11+
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
12+
13+
14+
def torch_gc():
15+
if torch.cuda.is_available():
16+
with torch.cuda.device(CUDA_DEVICE):
17+
torch.cuda.empty_cache()
18+
torch.cuda.ipc_collect()
19+
20+
21+
app = FastAPI()
22+
23+
24+
@app.post("/")
25+
async def create_item(request: Request):
26+
global model, tokenizer, prompt_template
27+
json_post_raw = await request.json()
28+
json_post = json.dumps(json_post_raw)
29+
json_post_list = json.loads(json_post)
30+
question = json_post_list.get('prompt')
31+
prompt = prompt_template.format(
32+
user_question=question.replace("#","")
33+
)
34+
sql_type = "自然语言转换成SQL查询"
35+
if sql_type in prompt:
36+
prompt += "```sql"
37+
else:
38+
prompt += ">>>"
39+
history = json_post_list.get('history')
40+
max_length = json_post_list.get('max_length')
41+
top_p = json_post_list.get('top_p')
42+
temperature = json_post_list.get('temperature')
43+
eos_token_id = tokenizer.convert_tokens_to_ids(["```"])[0]
44+
print("Loading a model and generating a SQL query for answering your question...")
45+
pipe = pipeline(
46+
"text-generation",
47+
model=model,
48+
tokenizer=tokenizer,
49+
max_new_tokens=300,
50+
do_sample=False,
51+
num_beams=5, # do beam search with 5 beams for high quality results
52+
)
53+
print("==========input========")
54+
print(prompt)
55+
generated_query = (
56+
pipe(
57+
prompt,
58+
num_return_sequences=1,
59+
eos_token_id=eos_token_id,
60+
pad_token_id=eos_token_id,
61+
)[0]["generated_text"]
62+
)
63+
64+
response = generated_query
65+
66+
if sql_type in prompt:
67+
response = response.split("`sql")[-1].split("`")[0].split(";")[0].strip() + ";"
68+
69+
else:
70+
response = response.split(">>>")[-1].split("`")[0].strip()
71+
72+
print("========output========")
73+
print(response)
74+
torch_gc()
75+
return response
76+
77+
78+
if __name__ == '__main__':
79+
prompt_template = ""
80+
with open("prompt.md", "r") as f:
81+
prompt_template = f.read()
82+
tokenizer = AutoTokenizer.from_pretrained("/mnt/workspace/sqlcoder-model/sqlcoder", trust_remote_code=True)
83+
model = AutoModelForCausalLM.from_pretrained("/mnt/workspace/sqlcoder-model/sqlcoder",
84+
trust_remote_code=True,
85+
# torch_dtype=torch.float16,
86+
load_in_8bit=True,
87+
device_map="auto",
88+
use_cache=True)
89+
ngrok_tunnel = ngrok.connect(8000)
90+
print('Public URL:', ngrok_tunnel.public_url)
91+
nest_asyncio.apply()
92+
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
93+
94+
95+

prompt.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
### Instructions:
2+
{user_question}
3+
### Response:
4+
Based on your instructions, here is the result I have generated to answer the question `{user_question}`:

requirements.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
tqdm==4.65.0
2+
transformers==4.28.1
3+
datasets==2.11.0
4+
huggingface-hub==0.13.4
5+
accelerate==0.18.0

0 commit comments

Comments
 (0)