Huggingface api worker #2456
Merged
Changes from 12 commits

Commits (15 total)
4f6e102  copy the model_worker.py (hnyls2002)
2fd1f7a  the basic structure (hnyls2002)
25b1a20  fix some format bugs (hnyls2002)
104400b  can roughly be used... (hnyls2002)
5e2e2dc  change the default name for falcon (hnyls2002)
f23c824  fixed the params["stop"]'s type bugs (hnyls2002)
931e2b8  Merge branch 'main' into huggingface-api-worker (hnyls2002)
6f41a86  change the literal condition for falcon-180b-chat (hnyls2002)
f1c8137  removed all unused imports (hnyls2002)
cd329a0  remove the unused members and args (hnyls2002)
6a7f4ba  rename the arg's name (hnyls2002)
3aa01c9  support multiple huggingface models (hnyls2002)
523f205  Merge branch 'lm-sys:main' into huggingface-api-worker (hnyls2002)
cc6b7ac  move all model-related arguments into --model-info-file (hnyls2002)
331133b  organize the imports using isort (hnyls2002)
@@ -0,0 +1,395 @@
"""
A model worker to call huggingface api.
JSON file format:
{
    "falcon-180b-chat": {
        "model_path": "tiiuae/falcon-180B-chat",
        "api_base": "https://api-inference.huggingface.co/models",
        "token": "hf_xxx",
        "context_length": 2048
    }
}
"""
import argparse
import asyncio
import json
from typing import List, Optional
import uuid

from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse, JSONResponse
from huggingface_hub import InferenceClient
import requests
from fastchat.serve.model_worker import BaseModelWorker

import uvicorn

from fastchat.constants import ErrorCode, SERVER_ERROR_MSG
from fastchat.utils import build_logger


worker_id = str(uuid.uuid4())[:8]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")

workers = []
worker_map = {}
app = FastAPI()


# reference to
# https://github.com/philschmid/easyllm/blob/cbd908b3b3f44a97a22cb0fc2c93df3660bacdad/easyllm/clients/huggingface.py#L374-L392
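# For example (illustrative values): params = {"temperature": 0.0, "top_p": 1.0,
# "stop": "###"} produces do_sample=False with "temperature" removed, top_p
# nudged to 0.9999999 so it stays strictly below 1.0, and stop_sequences == ["###"].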
def get_gen_kwargs(
    params,
    seed: Optional[int] = None,
):
    stop = params.get("stop", None)
    if isinstance(stop, list):
        stop_sequences = stop
    elif isinstance(stop, str):
        stop_sequences = [stop]
    else:
        stop_sequences = []
    gen_kwargs = {
        "do_sample": True,
        "return_full_text": bool(params.get("echo", False)),
        "max_new_tokens": int(params.get("max_new_tokens", 256)),
        "top_p": float(params.get("top_p", 1.0)),
        "temperature": float(params.get("temperature", 1.0)),
        "stop_sequences": stop_sequences,
        "repetition_penalty": float(params.get("repetition_penalty", 1.0)),
        "top_k": params.get("top_k", None),
        "seed": seed,
    }
    if gen_kwargs["top_p"] == 1:
        gen_kwargs["top_p"] = 0.9999999
    if gen_kwargs["top_p"] == 0:
        gen_kwargs.pop("top_p")
    if gen_kwargs["temperature"] == 0:
        gen_kwargs.pop("temperature")
        gen_kwargs["do_sample"] = False
    return gen_kwargs


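# Returns True while `text` ends with a prefix of any stop string, i.e. a stop
# sequence may still be forming; generate_stream_gate withholds output in that case.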
def could_be_stop(text, stop):
    for s in stop:
        if any(text.endswith(s[:i]) for i in range(1, len(s) + 1)):
            return True
    return False


class HuggingfaceApiWorker(BaseModelWorker):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        api_base: str,
        token: str,
        context_length: int,
        model_names: List[str],
        limit_worker_concurrency: int,
        no_register: bool,
        conv_template: Optional[str] = None,
        seed: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template=conv_template,
        )

        self.model_path = model_path
        self.api_base = api_base
        self.token = token
        self.context_len = context_length
        self.seed = seed

        logger.info(
            f"Connecting with huggingface api {self.model_path} as {self.model_names} on worker {worker_id} ..."
        )

    def count_token(self, params):
        # No tokenizer here
        ret = {
            "count": 0,
            "error_code": 0,
        }
        return ret

    def generate_stream_gate(self, params):
        self.call_ct += 1

        prompt = params["prompt"]
        gen_kwargs = get_gen_kwargs(params, seed=self.seed)
        stop = gen_kwargs["stop_sequences"]
        if "falcon" in self.model_path and "chat" in self.model_path:
            stop.extend(["\nUser:", "<|endoftext|>", " User:", "###"])
        stop = list(set(stop))
        gen_kwargs["stop_sequences"] = stop

        logger.info(f"prompt: {prompt}")
        logger.info(f"gen_kwargs: {gen_kwargs}")

        try:
            url = f"{self.api_base}/{self.model_path}"
            client = InferenceClient(url, token=self.token)
            res = client.text_generation(
                prompt, stream=True, details=True, **gen_kwargs
            )

            reason = None
            text = ""
            for chunk in res:
                if chunk.token.special:
                    continue
                text += chunk.token.text

                s = next((x for x in stop if text.endswith(x)), None)
                if s is not None:
                    text = text[: -len(s)]
                    reason = "stop"
                    break
                if could_be_stop(text, stop):
                    continue
                if (
                    chunk.details is not None
                    and chunk.details.finish_reason is not None
                ):
                    reason = chunk.details.finish_reason
                if reason not in ["stop", "length"]:
                    reason = None
                ret = {
                    "text": text,
                    "error_code": 0,
                    "finish_reason": reason,
                }
                yield json.dumps(ret).encode() + b"\0"
        except Exception as e:
            ret = {
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
                "error_code": ErrorCode.INTERNAL_ERROR,
            }
            yield json.dumps(ret).encode() + b"\0"

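    # Non-streaming path: drain generate_stream_gate and return its final chunk.
    # Each streamed chunk is a JSON object terminated by a NUL byte (b"\0").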
    def generate_gate(self, params):
        for x in self.generate_stream_gate(params):
            pass
        return json.loads(x[:-1].decode())

    def get_embeddings(self, params):
        raise NotImplementedError()


def release_worker_semaphore(worker):
    worker.semaphore.release()


def acquire_worker_semaphore(worker):
    if worker.semaphore is None:
        worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
    return worker.semaphore.acquire()


def create_background_tasks(worker):
    background_tasks = BackgroundTasks()
    background_tasks.add_task(lambda: release_worker_semaphore(worker))
    return background_tasks


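# FastAPI routes served by this worker process. A single process can host
# several HuggingfaceApiWorker instances, so each request picks its worker
# from worker_map via params["model"].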
@app.post("/worker_generate_stream")
async def api_generate_stream(request: Request):
    params = await request.json()
    worker = worker_map[params["model"]]
    await acquire_worker_semaphore(worker)
    generator = worker.generate_stream_gate(params)
    background_tasks = create_background_tasks(worker)
    return StreamingResponse(generator, background=background_tasks)


@app.post("/worker_generate")
async def api_generate(request: Request):
    params = await request.json()
    worker = worker_map[params["model"]]
    await acquire_worker_semaphore(worker)
    output = worker.generate_gate(params)
    release_worker_semaphore(worker)
    return JSONResponse(output)


@app.post("/worker_get_embeddings")
async def api_get_embeddings(request: Request):
    params = await request.json()
    worker = worker_map[params["model"]]
    await acquire_worker_semaphore(worker)
    embedding = worker.get_embeddings(params)
    release_worker_semaphore(worker)
    return JSONResponse(content=embedding)


@app.post("/worker_get_status")
async def api_get_status(request: Request):
    return {
        "model_names": [m for w in workers for m in w.model_names],
        "speed": 1,
        "queue_length": sum([w.get_queue_length() for w in workers]),
    }


@app.post("/count_token")
async def api_count_token(request: Request):
    params = await request.json()
    worker = worker_map[params["model"]]
    return worker.count_token(params)


@app.post("/worker_get_conv_template")
async def api_get_conv(request: Request):
    params = await request.json()
    worker = worker_map[params["model"]]
    return worker.get_conv_template()


@app.post("/model_details")
async def api_model_details(request: Request):
    params = await request.json()
    worker = worker_map[params["model"]]
    return {"context_length": worker.context_len}


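# Parse CLI arguments, load --model-info-file, build one HuggingfaceApiWorker
# per --model entry, and register them all with the controller under a single
# worker address.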
def create_huggingface_api_worker():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
    parser.add_argument(
        "--controller-address", type=str, default="http://localhost:21001"
    )
    parser.add_argument(
        "--model-info-file",
        type=str,
        required=True,
        help="Huggingface API model's info file path",
    )

    # support multi huggingface api models here
    parser.add_argument(
        "--model",
        type=str,
        default=[],
        action="append",
        help="The models' names to be called.",
    )
    parser.add_argument(
        "--model-names",
        type=lambda s: s.split(","),
        action="append",
        help="One or more model names. Values must be aligned with `--model` values.",
    )
    parser.add_argument(
        "--conv-template",
        type=str,
        default=None,
        action="append",
        help="Conversation prompt template. Values must be aligned with `--model` values. If only one value is provided, it will be repeated for all models.",
    )
    parser.add_argument(
        "--limit-worker-concurrency",
        type=int,
        default=5,
        help="Limit the model concurrency to prevent OOM.",
    )
    parser.add_argument("--no-register", action="store_true")
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Overwrite the random seed for each generation.",
    )
    args = parser.parse_args()

    if args.model_names is None:
        args.model_names = [[x.split("/")[-1]] for x in args.model]
    if args.conv_template is None:
        args.conv_template = [None] * len(args.model)
    elif len(args.conv_template) == 1:  # Repeat the same template
        args.conv_template = args.conv_template * len(args.model)

    with open(args.model_info_file, "r", encoding="UTF-8") as f:
        model_info = json.load(f)

    logger.info(f"args: {args}")

    model_path_list = []
    api_base_list = []
    token_list = []
    context_length_list = []

    for m in args.model:
        if m not in model_info:
            raise ValueError(
                f"Model {args.model} not supported. Please add it to {args.model_info_file}."
            )
        model_path_list.append(model_info[m]["model_path"])
        api_base_list.append(model_info[m]["api_base"])
        token_list.append(model_info[m]["token"])
        context_length_list.append(model_info[m]["context_length"])

    for (
        model_names,
        conv_template,
        model_path,
        api_base,
        token,
        context_length,
    ) in zip(
        args.model_names,
        args.conv_template,
        model_path_list,
        api_base_list,
        token_list,
        context_length_list,
    ):
        m = HuggingfaceApiWorker(
            args.controller_address,
            args.worker_address,
            worker_id,
            model_path,
            api_base,
            token,
            context_length,
            model_names,
            args.limit_worker_concurrency,
            no_register=args.no_register,
            conv_template=conv_template,
            seed=args.seed,
        )
        workers.append(m)
        for name in model_names:
            worker_map[name] = m

    # register all the models
    url = args.controller_address + "/register_worker"
    data = {
        "worker_name": workers[0].worker_addr,
        "check_heart_beat": not args.no_register,
        "worker_status": {
            "model_names": [m for w in workers for m in w.model_names],
            "speed": 1,
            "queue_length": sum([w.get_queue_length() for w in workers]),
        },
    }
    r = requests.post(url, json=data)
    assert r.status_code == 200

    return args, workers


if __name__ == "__main__":
    args, workers = create_huggingface_api_worker()
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
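
For reference, a minimal launch sketch. It assumes the new file is importable as fastchat.serve.huggingface_api_worker (the file path is not shown in this view), that a controller is already listening on the default http://localhost:21001 (the worker registers itself at startup), and it reuses the JSON layout from the module docstring with "hf_xxx" as a placeholder token:

# model_info.json
{
    "falcon-180b-chat": {
        "model_path": "tiiuae/falcon-180B-chat",
        "api_base": "https://api-inference.huggingface.co/models",
        "token": "hf_xxx",
        "context_length": 2048
    }
}

# Start the worker; --model must match a key in the info file.
python3 -m fastchat.serve.huggingface_api_worker \
    --model-info-file model_info.json \
    --model falcon-180b-chat

# Query the non-streaming endpoint directly; "model" selects the worker,
# "prompt" and the sampling fields map onto get_gen_kwargs above.
curl http://localhost:21002/worker_generate \
    -H "Content-Type: application/json" \
    -d '{"model": "falcon-180b-chat", "prompt": "Hello, who are you?", "temperature": 0.7, "max_new_tokens": 64}'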