diff --git a/docs/mlx_integration.md b/docs/mlx_integration.md
new file mode 100644
index 000000000..b500207a6
--- /dev/null
+++ b/docs/mlx_integration.md
@@ -0,0 +1,23 @@
+# Apple MLX Integration
+
+You can use [Apple MLX](https://github.com/ml-explore/mlx) as an optimized worker implementation in FastChat.
+
+It runs models efficiently on Apple Silicon.
+
+See the supported models [here](https://github.com/ml-explore/mlx-examples/tree/main/llms#supported-models).
+
+Note that for Apple Silicon Macs with less memory, smaller models (or quantized models) are recommended.
+
+## Instructions
+
+1. Install MLX.
+
+   ```
+   pip install mlx-lm
+   ```
+
+2. When you launch a model worker, replace the normal worker (`fastchat.serve.model_worker`) with the MLX worker (`fastchat.serve.mlx_worker`).
+
+   ```
+   python3 -m fastchat.serve.mlx_worker --model-path microsoft/phi-2
+   ```
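+
+3. Apart from the worker command, the serving workflow should be the same as with the default worker. For example, to chat with the model through FastChat's OpenAI-compatible REST API, start the controller before the worker and then launch the API server (default ports shown):
+
+   ```
+   python3 -m fastchat.serve.controller
+   python3 -m fastchat.serve.mlx_worker --model-path microsoft/phi-2
+   python3 -m fastchat.serve.openai_api_server --host localhost --port 8000
+   ```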
diff --git a/fastchat/serve/mlx_worker.py b/fastchat/serve/mlx_worker.py
new file mode 100644
index 000000000..bff5820ad
--- /dev/null
+++ b/fastchat/serve/mlx_worker.py
@@ -0,0 +1,286 @@
+"""
+A model worker using Apple MLX
+
+docs/mlx_integration.md
+
+https://github.com/ml-explore/mlx-examples/tree/main/llms
+
+Code based on vllm_worker https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/vllm_worker.py
+
+You must install MLX python:
+pip install mlx-lm
+"""
+
+import argparse
+import asyncio
+import atexit
+import json
+from typing import List
+import uuid
+
+from fastapi import FastAPI, Request, BackgroundTasks
+from fastapi.responses import StreamingResponse, JSONResponse
+import uvicorn
+
+from fastchat.serve.base_model_worker import BaseModelWorker
+from fastchat.serve.model_worker import (
+    logger,
+    worker_id,
+)
+from fastchat.utils import get_context_length, is_partial_stop
+
+import mlx.core as mx
+from mlx_lm import load, generate
+from mlx_lm.utils import generate_step
+
+app = FastAPI()
+
+
+class MLXWorker(BaseModelWorker):
+    def __init__(
+        self,
+        controller_addr: str,
+        worker_addr: str,
+        worker_id: str,
+        model_path: str,
+        model_names: List[str],
+        limit_worker_concurrency: int,
+        no_register: bool,
+        llm_engine: "MLX",
+        conv_template: str,
+    ):
+        super().__init__(
+            controller_addr,
+            worker_addr,
+            worker_id,
+            model_path,
+            model_names,
+            limit_worker_concurrency,
+            conv_template,
+        )
+
+        logger.info(
+            f"Loading the model {self.model_names} on worker {worker_id}, worker type: MLX worker..."
+        )
+
+        self.model_name = model_path
+        self.mlx_model, self.mlx_tokenizer = load(model_path)
+
+        self.tokenizer = self.mlx_tokenizer
+        # self.context_len = get_context_length(
+        #     llm_engine.engine.model_config.hf_config)
+        self.context_len = 2048  # hard code for now -- not sure how to get in MLX
+
+        if not no_register:
+            self.init_heart_beat()
+
+    async def generate_stream(self, params):
+        self.call_ct += 1
+
+        context = params.pop("prompt")
+        request_id = params.pop("request_id")
+        temperature = float(params.get("temperature", 1.0))
+        top_p = float(params.get("top_p", 1.0))
+        top_k = params.get("top_k", -1.0)
+        presence_penalty = float(params.get("presence_penalty", 0.0))
+        frequency_penalty = float(params.get("frequency_penalty", 0.0))
+        max_new_tokens = params.get("max_new_tokens", 256)
+        stop_str = params.get("stop", None)
+        stop_token_ids = params.get("stop_token_ids", None) or []
+        if self.tokenizer.eos_token_id is not None:
+            stop_token_ids.append(self.tokenizer.eos_token_id)
+        echo = params.get("echo", True)
+        use_beam_search = params.get("use_beam_search", False)
+        best_of = params.get("best_of", None)
+
+        # Handle stop_str
+        stop = set()
+        if isinstance(stop_str, str) and stop_str != "":
+            stop.add(stop_str)
+        elif isinstance(stop_str, list) and stop_str != []:
+            stop.update(stop_str)
+
+        for tid in stop_token_ids:
+            if tid is not None:
+                s = self.tokenizer.decode(tid)
+                if s != "":
+                    stop.add(s)
+
+        print("Stop patterns: ", stop)
+
+        top_p = max(top_p, 1e-5)
+        if temperature <= 1e-5:
+            top_p = 1.0
+
+        tokens = []
+        skip = 0
+
+        context_mlx = mx.array(self.tokenizer.encode(context))
+
+        finish_reason = "length"
+
+        for token, _ in zip(
+            generate_step(context_mlx, self.mlx_model, temperature),
+            range(max_new_tokens),
+        ):
+            if token == self.mlx_tokenizer.eos_token_id:
+                finish_reason = "stop"
+                break
+            tokens.append(token.item())
+            tokens_decoded = self.mlx_tokenizer.decode(tokens)
+            last_token_decoded = self.mlx_tokenizer.decode([token.item()])
+            skip = len(tokens_decoded)
+
+            partial_stop = any(is_partial_stop(tokens_decoded, i) for i in stop)
+
+            if partial_stop:
+                finish_reason = "stop"
+                break
+
+            ret = {
+                "text": tokens_decoded,
+                "error_code": 0,
+                "usage": {
+                    "prompt_tokens": len(context),
+                    "completion_tokens": len(tokens),
+                    "total_tokens": len(context) + len(tokens),
+                },
+                "cumulative_logprob": [],
+                "finish_reason": None,  # hard code for now
+            }
+            # print(ret)
+            yield (json.dumps(ret) + "\0").encode()
+        ret = {
+            "text": self.mlx_tokenizer.decode(tokens),
+            "error_code": 0,
+            "usage": {},
+            "cumulative_logprob": [],
+            "finish_reason": finish_reason,
+        }
+        yield (json.dumps(obj={**ret, **{"finish_reason": None}}) + "\0").encode()
+        yield (json.dumps(ret) + "\0").encode()
+
+    async def generate(self, params):
+        async for x in self.generate_stream(params):
+            pass
+        return json.loads(x[:-1].decode())
+
+
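+# The module-level semaphore below caps the worker at
+# `limit_worker_concurrency` in-flight requests: it is acquired before a
+# request is handled and released either directly (non-streaming) or by a
+# FastAPI background task once the streaming response has finished.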
+def release_worker_semaphore():
+    worker.semaphore.release()
+
+
+def acquire_worker_semaphore():
+    if worker.semaphore is None:
+        worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
+    return worker.semaphore.acquire()
+
+
+def create_background_tasks(request_id):
+    async def abort_request() -> None:
+        print("trying to abort but not implemented")
+
+    background_tasks = BackgroundTasks()
+    background_tasks.add_task(release_worker_semaphore)
+    background_tasks.add_task(abort_request)
+    return background_tasks
+
+
+@app.post("/worker_generate_stream")
+async def api_generate_stream(request: Request):
+    params = await request.json()
+    await acquire_worker_semaphore()
+    request_id = uuid.uuid4()
+    params["request_id"] = str(request_id)
+    generator = worker.generate_stream(params)
+    background_tasks = create_background_tasks(request_id)
+    return StreamingResponse(generator, background=background_tasks)
+
+
+@app.post("/worker_generate")
+async def api_generate(request: Request):
+    params = await request.json()
+    await acquire_worker_semaphore()
+    request_id = uuid.uuid4()
+    params["request_id"] = str(request_id)
+    output = await worker.generate(params)
+    release_worker_semaphore()
+    # await engine.abort(request_id)
+    print("Trying to abort but not implemented")
+    return JSONResponse(output)
+
+
+@app.post("/worker_get_status")
+async def api_get_status(request: Request):
+    return worker.get_status()
+
+
+@app.post("/count_token")
+async def api_count_token(request: Request):
+    params = await request.json()
+    return worker.count_token(params)
+
+
+@app.post("/worker_get_conv_template")
+async def api_get_conv(request: Request):
+    return worker.get_conv_template()
+
+
+@app.post("/model_details")
+async def api_model_details(request: Request):
+    return {"context_length": worker.context_len}
+
+
+worker = None
+
+
+def cleanup_at_exit():
+    global worker
+    print("Cleaning up...")
+    del worker
+
+
+atexit.register(cleanup_at_exit)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument("--port", type=int, default=21002)
+    parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
+    parser.add_argument(
+        "--controller-address", type=str, default="http://localhost:21001"
+    )
+    parser.add_argument("--model-path", type=str, default="microsoft/phi-2")
+    parser.add_argument(
+        "--model-names",
+        type=lambda s: s.split(","),
+        help="Optional display names, comma separated.",
+    )
+    parser.add_argument(
+        "--conv-template", type=str, default=None, help="Conversation prompt template."
+    )
+    parser.add_argument(
+        "--trust_remote_code",
+        action="store_false",
+        default=True,
+        help="Trust remote code (e.g., from HuggingFace) when "
+        "downloading the model and tokenizer.",
+    )
+
+    args, unknown = parser.parse_known_args()
+
+    if args.model_path:
+        args.model = args.model_path
+
+    worker = MLXWorker(
+        args.controller_address,
+        args.worker_address,
+        worker_id,
+        args.model_path,
+        args.model_names,
+        1024,
+        False,
+        "MLX",
+        args.conv_template,
+    )
+    uvicorn.run(app, host=args.host, port=args.port, log_level="info")