-
Notifications
You must be signed in to change notification settings - Fork 4.8k
move BaseModelWorker outside serve.model_worker to make it independent #2531
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
7f0141c
move BaseModelWorker outside serve.model_worker to make it independent
liunux4odoo c0db198
- build a default logger in BaseModelWorker.__init__ to avoid duplicated logger setup
liunux4odoo 6919345
Merge branch 'main' into fix
liunux4odoo 9a5d87b
Organize the import statements as three blocks
liunux4odoo 2ff03bf
fix: worker is None
liunux4odoo File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,240 @@ | ||
import asyncio
import threading
import time
import uuid
from typing import List, Optional

import requests
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse, JSONResponse

from fastchat.constants import WORKER_HEART_BEAT_INTERVAL
from fastchat.conversation import Conversation
from fastchat.utils import pretty_print_semaphore, build_logger
|
|
||
|
|
||
# Short random id used to name this worker's log file.
worker_id = str(uuid.uuid4())[:8]
# Module-level singletons: set once by the first BaseModelWorker constructed.
worker = None
logger = None

app = FastAPI()
|
|
||
|
|
||
def heart_beat_worker(obj):
    """Forever ping the controller on behalf of *obj*.

    Sleeps WORKER_HEART_BEAT_INTERVAL seconds between beats; intended to be
    the target of a daemon thread (see BaseModelWorker.init_heart_beat).
    """
    while True:
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
        obj.send_heart_beat()
|
|
||
|
|
||
class BaseModelWorker:
    """Base class for FastChat model workers.

    Handles the controller handshake (registration plus periodic heart
    beats), conversation-template selection, and status / queue-length
    reporting.  Subclasses implement the actual inference entry points:
    generate_stream_gate, generate_gate, and get_embeddings.
    """

    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        conv_template: Optional[str] = None,
    ):
        """
        Args:
            controller_addr: base URL of the controller service.
            worker_addr: externally reachable URL of this worker.
            worker_id: short identifier, used in the log file name.
            model_path: local path or hub id of the model; a trailing "/"
                is stripped before deriving a default model name.
            model_names: names to advertise to the controller; falls back
                to the last component of ``model_path`` when empty/None.
            limit_worker_concurrency: maximum concurrent requests
                (size of the lazily created semaphore).
            conv_template: optional explicit conversation-template name.
        """
        global logger, worker

        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        self.model_names = model_names or [model_path.split("/")[-1]]
        self.limit_worker_concurrency = limit_worker_concurrency
        self.conv = self.make_conv_template(conv_template, model_path)
        # NOTE(review): sep_style is coerced from its enum to a plain int,
        # presumably so the template serializes cleanly over the HTTP API
        # (/worker_get_conv_template) — confirm against callers.
        self.conv.sep_style = int(self.conv.sep_style)
        self.tokenizer = None  # set by subclasses that load a model
        self.context_len = None  # set by subclasses; exposed by /model_details
        self.call_ct = 0  # number of requests processed
        self.semaphore = None  # created lazily by acquire_worker_semaphore

        self.heart_beat_thread = None  # started by init_heart_beat()

        # Initialize the module-level singletons exactly once so several
        # workers in one process do not build duplicate loggers.
        if logger is None:
            logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
        if worker is None:
            worker = self

    def make_conv_template(
        self,
        conv_template: Optional[str] = None,
        model_path: Optional[str] = None,
    ) -> "Conversation":
        """Return the conversation template for this worker.

        Can be overridden to customize the conversation template for
        different model workers.  Uses the explicit template name when
        given, otherwise infers one from the model path.
        """
        # Imported lazily to keep module import light and avoid cycles.
        from fastchat.conversation import get_conv_template
        from fastchat.model.model_adapter import get_conversation_template

        if conv_template:
            return get_conv_template(conv_template)
        return get_conversation_template(model_path)

    def init_heart_beat(self):
        """Register with the controller and start the daemon heart-beat thread."""
        self.register_to_controller()
        self.heart_beat_thread = threading.Thread(
            target=heart_beat_worker,
            args=(self,),
            daemon=True,
        )
        self.heart_beat_thread.start()

    def register_to_controller(self):
        """Announce this worker and its current status to the controller.

        Raises:
            AssertionError: if the controller does not answer with HTTP 200.
        """
        logger.info("Register to controller")

        url = self.controller_addr + "/register_worker"
        data = {
            "worker_name": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status(),
        }
        r = requests.post(url, json=data)
        assert r.status_code == 200

    def send_heart_beat(self):
        """Send one heart beat, retrying every 5 s until the controller replies.

        If the controller reports that it no longer knows this worker
        (``exist`` is falsy), re-register.
        """
        logger.info(
            f"Send heart beat. Models: {self.model_names}. "
            f"Semaphore: {pretty_print_semaphore(self.semaphore)}. "
            f"call_ct: {self.call_ct}. "
            f"worker_id: {self.worker_id}. "
        )

        url = self.controller_addr + "/receive_heart_beat"

        while True:
            try:
                ret = requests.post(
                    url,
                    json={
                        "worker_name": self.worker_addr,
                        "queue_length": self.get_queue_length(),
                    },
                    timeout=5,
                )
                exist = ret.json()["exist"]
                break
            except (requests.exceptions.RequestException, KeyError) as e:
                # Controller unreachable or malformed reply: log and retry.
                logger.error(f"heart beat error: {e}")
            time.sleep(5)

        if not exist:
            self.register_to_controller()

    def get_queue_length(self):
        """Best-effort count of requests holding or waiting for a slot.

        NOTE(review): relies on asyncio.Semaphore private attributes
        (``_value``, ``_waiters``); these are CPython implementation
        details and may change across Python versions.
        """
        if (
            self.semaphore is None
            or self.semaphore._value is None
            or self.semaphore._waiters is None
        ):
            return 0
        else:
            return (
                self.limit_worker_concurrency
                - self.semaphore._value
                + len(self.semaphore._waiters)
            )

    def get_status(self):
        """Return the status dict reported to the controller."""
        return {
            "model_names": self.model_names,
            "speed": 1,
            "queue_length": self.get_queue_length(),
        }

    def count_token(self, params):
        """Count the tokens of ``params["prompt"]`` with this worker's tokenizer.

        Returns:
            dict with ``count`` (token count) and ``error_code`` (always 0).
        """
        prompt = params["prompt"]

        try:
            input_ids = self.tokenizer(prompt).input_ids
            input_echo_len = len(input_ids)
        except TypeError:
            # Some tokenizers are not callable and expose num_tokens instead.
            input_echo_len = self.tokenizer.num_tokens(prompt)

        ret = {
            "count": input_echo_len,
            "error_code": 0,
        }
        return ret

    def get_conv_template(self):
        """Return the conversation template wrapped for the HTTP API."""
        return {"conv": self.conv}

    def generate_stream_gate(self, params):
        """Yield streaming generation chunks.  Implemented by subclasses."""
        raise NotImplementedError

    def generate_gate(self, params):
        """Return a full (non-streaming) generation.  Implemented by subclasses."""
        raise NotImplementedError

    def get_embeddings(self, params):
        """Return embeddings for the input.  Implemented by subclasses."""
        raise NotImplementedError
|
|
||
|
|
||
def release_worker_semaphore():
    """Return one concurrency slot to the active worker's semaphore."""
    worker.semaphore.release()
|
|
||
|
|
||
def acquire_worker_semaphore():
    """Return an awaitable that acquires one concurrency slot.

    Creates the worker's semaphore on first use so it is bound to the
    running event loop.
    """
    sem = worker.semaphore
    if sem is None:
        sem = asyncio.Semaphore(worker.limit_worker_concurrency)
        worker.semaphore = sem
    return sem.acquire()
|
|
||
|
|
||
def create_background_tasks():
    """Build a BackgroundTasks that frees the semaphore after the response."""
    tasks = BackgroundTasks()
    tasks.add_task(release_worker_semaphore)
    return tasks
|
|
||
|
|
||
| @app.post("/worker_generate_stream") | ||
| async def api_generate_stream(request: Request): | ||
| params = await request.json() | ||
| await acquire_worker_semaphore() | ||
| generator = worker.generate_stream_gate(params) | ||
| background_tasks = create_background_tasks() | ||
| return StreamingResponse(generator, background=background_tasks) | ||
|
|
||
|
|
||
| @app.post("/worker_generate") | ||
| async def api_generate(request: Request): | ||
| params = await request.json() | ||
| await acquire_worker_semaphore() | ||
| output = worker.generate_gate(params) | ||
| release_worker_semaphore() | ||
| return JSONResponse(output) | ||
|
|
||
|
|
||
| @app.post("/worker_get_embeddings") | ||
| async def api_get_embeddings(request: Request): | ||
| params = await request.json() | ||
| await acquire_worker_semaphore() | ||
| embedding = worker.get_embeddings(params) | ||
| release_worker_semaphore() | ||
| return JSONResponse(content=embedding) | ||
|
|
||
|
|
||
| @app.post("/worker_get_status") | ||
| async def api_get_status(request: Request): | ||
| return worker.get_status() | ||
|
|
||
|
|
||
| @app.post("/count_token") | ||
| async def api_count_token(request: Request): | ||
| params = await request.json() | ||
| return worker.count_token(params) | ||
|
|
||
|
|
||
| @app.post("/worker_get_conv_template") | ||
| async def api_get_conv(request: Request): | ||
| return worker.get_conv_template() | ||
|
|
||
|
|
||
| @app.post("/model_details") | ||
| async def api_model_details(request: Request): | ||
| return {"context_length": worker.context_len} | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.