diff --git a/fastchat/serve/model_worker.py b/fastchat/serve/model_worker.py index 5e84a4262..93ccaa54a 100644 --- a/fastchat/serve/model_worker.py +++ b/fastchat/serve/model_worker.py @@ -31,7 +31,6 @@ str_to_torch_dtype, ) - worker_id = str(uuid.uuid4())[:8] logger = build_logger("model_worker", f"model_worker_{worker_id}.log") @@ -101,6 +100,10 @@ def __init__( self.init_heart_beat() def generate_stream_gate(self, params): + if self.device == "npu": + import torch_npu + + torch_npu.npu.set_device("npu:0") self.call_ct += 1 try: