29 changes: 20 additions & 9 deletions examples/llm/components/prefill_worker.py
@@ -39,6 +39,14 @@ class RequestType(BaseModel):
text: str


async def wrap_generator(generator):
    # Drain an async generator, discarding its yielded values; returns once the
    # generator is exhausted so the caller can await completion of a prefill.
    while True:
        try:
            await generator.__anext__()
        except StopAsyncIteration:
            break


Comment on lines +42 to +49
Contributor

nit: how about defining this inside the PrefillWorker class, so that if we have a batched generate in the future we can swap it in directly?

Contributor Author

Hmm. That would mean we'd also need to bring the PrefillWorker.generate call into the prefill queue, which might get a bit weird.
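
For context, a minimal sketch of the reviewer's suggestion above (hypothetical; the PR keeps the helper at module level): the drain helper lives on PrefillWorker, so a future batched generate could be swapped in without touching the queue handler.

import asyncio


class PrefillWorker:
    async def generate(self, request):
        # Placeholder for the real prefill path; yields engine outputs.
        yield request

    @staticmethod
    async def _drain(generator):
        # Consume an async generator, discarding its yielded values.
        async for _ in generator:
            pass

    async def generate_batch(self, requests):
        # A real batched generate could later replace this gather without
        # changing the prefill queue handler.
        await asyncio.gather(*(self._drain(self.generate(r)) for r in requests))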

@service(
dynamo={
"namespace": "dynamo",
@@ -152,15 +160,18 @@ async def prefill_queue_handler(self):
) as prefill_queue:
logger.info("prefill queue handler started")
while True:
# TODO: this might add a small overhead to pull prefill from nats
# need to test and check how much overhead it is
prefill_request = await prefill_queue.dequeue_prefill_request()
if prefill_request is not None:
logger.info(
f"Dequeued prefill request: {prefill_request.request_id}"
)
async for _ in self.generate(prefill_request):
pass
reqs = await prefill_queue.dequeue_prefill_request_batch(
self.engine_args.max_batched_prefill_tokens,
self.engine_args.block_size,
)

if reqs is not None:
logger.debug(f"Running batch of {len(reqs)} prefill requests")
# Create futures which indicate the completion of each prefill.
futures = [wrap_generator(self.generate(req)) for req in reqs]
# Wait for all prefills to complete.
await asyncio.gather(*futures)

if self.shutdown_requested:
logger.info(
"Shutdown requested, checking if engine has any pending prefill sending requests"
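
The handler change above drains one generate() call per dequeued request and waits for the whole batch with asyncio.gather. A self-contained sketch of that drain-and-gather pattern, with a toy async generator standing in for PrefillWorker.generate:

import asyncio


async def wrap_generator(generator):
    # Drain an async generator, ignoring whatever it yields.
    while True:
        try:
            await generator.__anext__()
        except StopAsyncIteration:
            break


async def fake_prefill(request_id: int):
    # Stand-in for PrefillWorker.generate: does some work, yields one result.
    await asyncio.sleep(0.01)
    yield f"prefill-done-{request_id}"


async def handle_batch(request_ids):
    # One awaitable per prefill; gather returns once every prefill completes.
    await asyncio.gather(*(wrap_generator(fake_prefill(r)) for r in request_ids))


asyncio.run(handle_batch([1, 2, 3]))
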
1 change: 1 addition & 0 deletions examples/llm/configs/disagg.yaml
@@ -40,6 +40,7 @@ VllmWorker:

PrefillWorker:
max-num-batched-tokens: 16384
max-batched-prefill-tokens: 2048
ServiceArgs:
workers: 1
resources:
1 change: 1 addition & 0 deletions examples/llm/configs/disagg_router.yaml
@@ -47,6 +47,7 @@ VllmWorker:

PrefillWorker:
max-num-batched-tokens: 16384
max-batched-prefill-tokens: 2048
ServiceArgs:
workers: 1
resources:
55 changes: 52 additions & 3 deletions examples/llm/utils/prefill_queue.py
@@ -14,7 +14,7 @@
# limitations under the License.


from typing import Optional
from typing import List, Optional

import msgspec
from utils.nats_queue import NATSQueue
@@ -39,18 +39,67 @@ def __init__(
dequeue_timeout=dequeue_timeout,
)

self.pending = None

async def enqueue_prefill_request(
self, prefill_request: RemotePrefillRequest
) -> None:
encoded_request = msgspec.json.encode(prefill_request)
await self.enqueue_task(encoded_request)

async def dequeue_prefill_request(self) -> Optional[RemotePrefillRequest]:
encoded_request = await self.dequeue_task()
async def dequeue_prefill_request(
self, timeout: Optional[float] = None
) -> Optional[RemotePrefillRequest]:
encoded_request = await self.dequeue_task(timeout)
if encoded_request is not None:
prefill_request = msgspec.json.decode(
encoded_request, type=RemotePrefillRequest
)
return prefill_request
else:
return None

async def dequeue_prefill_request_batch(
self, max_batched_prefill_tokens: int, block_size: int
) -> Optional[List[RemotePrefillRequest]]:
def num_new_tokens(req: RemotePrefillRequest) -> int:
return len(req.prompt_token_ids) - len(req.computed_block_ids) * block_size

req = (
self.pending
if self.pending is not None
else await self.dequeue_prefill_request()
)

if req is None:
return None

reqs = [req]

# Reset the pending request (if any).
self.pending = None
# Determine how much margin we have for more requests in the same batch.
remaining_prefill_tokens = max_batched_prefill_tokens - num_new_tokens(req)

if remaining_prefill_tokens < 0:
return reqs

# TODO: We might want to double-buffer this process
# to avoid the overhead of dequeuing from nats
prefill_queue_size = await self.get_queue_size()
for _ in range(prefill_queue_size):
# This should be immediate, hence the zero timeout.
req = await self.dequeue_prefill_request(0)

if req is None:
break

if num_new_tokens(req) <= remaining_prefill_tokens:
reqs.append(req)
remaining_prefill_tokens -= num_new_tokens(req)
else:
# We need to save this request for the next batch.
self.pending = req
break

return reqs
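
To make the token accounting in dequeue_prefill_request_batch concrete, here is a small worked example with illustrative numbers; the request type is reduced to the two fields the batching logic reads, and the block size and budget values are assumptions chosen for the arithmetic.

from dataclasses import dataclass
from typing import List


@dataclass
class Req:
    prompt_token_ids: List[int]
    computed_block_ids: List[int]


BLOCK_SIZE = 16
MAX_BATCHED_PREFILL_TOKENS = 2048


def num_new_tokens(req: Req) -> int:
    # Tokens still needing prefill: prompt length minus the tokens already
    # covered by computed (prefix-cached) blocks.
    return len(req.prompt_token_ids) - len(req.computed_block_ids) * BLOCK_SIZE


# 1200-token prompt with 10 cached blocks -> 1200 - 10 * 16 = 1040 new tokens.
a = Req(prompt_token_ids=list(range(1200)), computed_block_ids=list(range(10)))
# 900-token prompt with no cached blocks -> 900 new tokens.
b = Req(prompt_token_ids=list(range(900)), computed_block_ids=[])

budget = MAX_BATCHED_PREFILL_TOKENS - num_new_tokens(a)  # 2048 - 1040 = 1008
print(num_new_tokens(a), num_new_tokens(b), budget)  # 1040 900 1008
# b fits inside the remaining budget, so both requests land in one batch; a
# third large request would be stashed in self.pending for the next batch.
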
7 changes: 7 additions & 0 deletions examples/llm/utils/vllm.py
@@ -69,6 +69,12 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
default=3,
help="Maximum queue size for remote prefill. If the prefill queue size is greater than this value, prefill phase of the incoming request will be executed locally.",
)
parser.add_argument(
"--max-batched-prefill-tokens",
type=int,
default=2048,
help="Maximum number of tokens to prefill in a single batch. If the number of tokens to prefill is greater than this value, prefill phase will execute as bs=1.",
)
Comment on lines +74 to +79
Contributor

"Maximum number of tokens to prefill in a single batch"

Is this an accurate name/description? With the NATS queue, we can only stop dequeuing once the batched prefill tokens already exceed this value; in other words, the actual number of prefill tokens can end up larger than this value.

Contributor Author

Alternatively, this could be the max, where we dequeue an extra one (that we find exceeds the max) and carry it over to the next batch.

Contributor

I don't think that's a good idea. We should send whatever we dequeue to the engine immediately, or we shouldn't dequeue and hope other prefill workers will pick it up.

How about naming it --min-batched-prefill-tokens and explaining that it only applies if the prefill queue has that many tokens?
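
A hypothetical sketch of the --min-batched-prefill-tokens idea from the comment above (not part of this PR): everything dequeued is returned and sent to the engine immediately, and extra dequeues only happen while the accumulated new tokens are below the minimum. The function name and signature here are illustrative only.

async def dequeue_with_minimum(queue, min_batched_prefill_tokens: int, block_size: int):
    # Hypothetical alternative to dequeue_prefill_request_batch: no carry-over,
    # no self.pending; whatever is dequeued goes straight into the batch.
    def num_new_tokens(req):
        return len(req.prompt_token_ids) - len(req.computed_block_ids) * block_size

    req = await queue.dequeue_prefill_request()
    if req is None:
        return None

    reqs, total = [req], num_new_tokens(req)
    while total < min_batched_prefill_tokens:
        # Non-blocking follow-up dequeues; stop as soon as the queue drains.
        req = await queue.dequeue_prefill_request(0)
        if req is None:
            break
        reqs.append(req)
        total += num_new_tokens(req)
    return reqs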

parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args(vllm_args)
engine_args = AsyncEngineArgs.from_cli_args(args)
@@ -78,4 +84,5 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
engine_args.conditional_disagg = args.conditional_disagg
engine_args.max_local_prefill_length = args.max_local_prefill_length
engine_args.max_prefill_queue_size = args.max_prefill_queue_size
engine_args.max_batched_prefill_tokens = args.max_batched_prefill_tokens
return engine_args