 
 import os
 import json
+import typing
 import contextlib
 
 from threading import Lock
@@ -88,11 +89,12 @@ def get_llama_proxy():
             llama_outer_lock.release()
 
 
-_ping_message_factory = None
+_ping_message_factory: typing.Optional[typing.Callable[[], bytes]] = None
 
-def set_ping_message_factory(factory):
-    global _ping_message_factory
-    _ping_message_factory = factory
+
+def set_ping_message_factory(factory: typing.Callable[[], bytes]):
+    global _ping_message_factory
+    _ping_message_factory = factory
 
 
 def create_app(
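`set_ping_message_factory` only stores the callable in the module-level `_ping_message_factory`; how the server consumes it is outside this diff. A minimal sketch of registering a factory, assuming the module path `llama_cpp.server.app` and an illustrative SSE comment payload:

```python
from llama_cpp.server import app as server_app  # assumed module path for this file

# Any zero-argument callable returning bytes satisfies the new
# typing.Callable[[], bytes] hint; the ": ping" SSE comment below is only an
# illustrative payload, not a documented default.
server_app.set_ping_message_factory(lambda: b": ping\r\n\r\n")
```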
@@ -155,20 +157,19 @@ def create_app(
 
 async def get_event_publisher(
     request: Request,
-    inner_send_chan: MemoryObjectSendStream,
-    iterator: Iterator,
-    on_complete=None,
+    inner_send_chan: MemoryObjectSendStream[typing.Any],
+    iterator: Iterator[typing.Any],
+    on_complete: typing.Optional[typing.Callable[[], None]] = None,
 ):
+    server_settings = next(get_server_settings())
+    interrupt_requests = server_settings.interrupt_requests if server_settings else False
     async with inner_send_chan:
         try:
             async for chunk in iterate_in_threadpool(iterator):
                 await inner_send_chan.send(dict(data=json.dumps(chunk)))
                 if await request.is_disconnected():
                     raise anyio.get_cancelled_exc_class()()
-                if (
-                    next(get_server_settings()).interrupt_requests
-                    and llama_outer_lock.locked()
-                ):
+                if interrupt_requests and llama_outer_lock.locked():
                     await inner_send_chan.send(dict(data="[DONE]"))
                     raise anyio.get_cancelled_exc_class()()
             await inner_send_chan.send(dict(data="[DONE]"))
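Hoisting `next(get_server_settings())` out of the per-chunk loop reads the settings once per stream instead of on every chunk, while the `[DONE]` sentinel is still sent before cancelling. A hedged client-side sketch of consuming that stream, assuming the default `/v1/completions` route on `localhost:8000`:

```python
import json
import requests  # assumed HTTP client; any SSE-capable client works

# Read the server-sent events emitted by get_event_publisher: each event is
# "data: <json chunk>", and the stream ends with "data: [DONE]" (also sent
# early when interrupt_requests is enabled and another request is waiting).
with requests.post(
    "http://localhost:8000/v1/completions",  # assumed default host, port, and route
    json={"prompt": "Hello", "stream": True},
    stream=True,
) as resp:
    for raw in resp.iter_lines():
        if not raw:
            continue
        payload = raw.decode().removeprefix("data: ")
        if payload == "[DONE]":
            break
        chunk = json.loads(payload)
        print(chunk["choices"][0]["text"], end="", flush=True)
```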
@@ -268,6 +269,11 @@ async def create_completion(
     llama_proxy = await run_in_threadpool(
         lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)())
     )
+    if llama_proxy is None:
+        raise HTTPException(
+            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+            detail="Service is not available",
+        )
     if isinstance(body.prompt, list):
         assert len(body.prompt) <= 1
         body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
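With this guard, a missing model proxy surfaces as an HTTP 503 instead of failing later when the `None` proxy is used. A sketch of how a client might treat that status as "retry later"; the route, port, and back-off policy are assumptions:

```python
import time
import requests  # assumed HTTP client

def complete_with_retry(prompt: str, retries: int = 3) -> dict:
    """Retry on 503 ("Service is not available") with exponential back-off."""
    for attempt in range(retries):
        resp = requests.post(
            "http://localhost:8000/v1/completions",  # assumed default route
            json={"prompt": prompt, "max_tokens": 16},
        )
        if resp.status_code != 503:
            resp.raise_for_status()
            return resp.json()
        time.sleep(2 ** attempt)  # back off while the service is unavailable
    raise RuntimeError("Service is not available")
```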
@@ -409,7 +415,7 @@ async def create_chat_completion(
                         {"role": "system", "content": "You are a helpful assistant."},
                         {"role": "user", "content": "Who won the world series in 2020"},
                     ],
-                    "response_format": { "type": "json_object" }
+                    "response_format": {"type": "json_object"},
                 },
             },
             "tool_calling": {
@@ -434,15 +440,15 @@ async def create_chat_completion(
                                     },
                                     "required": ["name", "age"],
                                 },
-                            }
+                            },
                         }
                     ],
                     "tool_choice": {
                         "type": "function",
                         "function": {
                             "name": "User",
-                        }
-                    }
+                        },
+                    },
                 },
             },
             "logprobs": {
@@ -454,7 +460,7 @@ async def create_chat_completion(
                         {"role": "user", "content": "What is the capital of France?"},
                     ],
                     "logprobs": True,
-                    "top_logprobs": 10
+                    "top_logprobs": 10,
                 },
             },
         }
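The trailing-comma fixes above only touch the `openapi_examples` shown in the docs UI. For reference, a request mirroring the "json_mode" and "logprobs" examples, assuming an OpenAI-compatible client pointed at a local server:

```python
from openai import OpenAI  # assumed OpenAI-compatible client

# base_url and api_key are placeholders for a locally running server.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="sk-no-key-required")

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Who won the world series in 2020"},
    ],
    response_format={"type": "json_object"},  # from the "json_mode" example
    logprobs=True,
    top_logprobs=10,                          # from the "logprobs" example
)
print(response.choices[0].message.content)
```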
@@ -468,6 +474,11 @@ async def create_chat_completion(
     llama_proxy = await run_in_threadpool(
         lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)())
     )
+    if llama_proxy is None:
+        raise HTTPException(
+            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+            detail="Service is not available",
+        )
     exclude = {
         "n",
         "logit_bias_type",