fix: serialize disagg first_gen_log_probs int keys for Rust transport#7145
Conversation
|
👋 Hi nv-yna! Thank you for contributing to ai-dynamo/dynamo. Just a reminder: The 🚀 |
a5663fd to
da8a7bd
Compare
da8a7bd to
2ea183a
Compare
WalkthroughThis pull request adds serialization and deserialization logic for first_gen_log_probs in the disaggregated request handling pipeline. New codec methods transform log probabilities between TRT-LLM's internal dictionary format and a transport-compatible list format, with handler integration and comprehensive test coverage. Changes
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes Poem
🚥 Pre-merge checks | ✅ 3✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Tip Try Coding Plans. Let us write the prompt for your AI agent so you can ship faster (with fewer bugs). Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (1)
tests/serve/test_disagg_logprobs_serialization.py (1)
104-107: Addstrict=Truetozip()for safer iteration.The static analysis flagged this
zip()call. Sinceoriginalandrecoveredshould have equal lengths after a successful round-trip, addingstrict=Truewould catch any unexpected length mismatch.🔧 Proposed fix
- for orig_pos, rec_pos in zip(original, recovered): + for orig_pos, rec_pos in zip(original, recovered, strict=True):🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/serve/test_disagg_logprobs_serialization.py` around lines 104 - 107, The zip over original and recovered should be strict to catch length mismatches; update the loop that iterates with "for orig_pos, rec_pos in zip(original, recovered):" to use "zip(original, recovered, strict=True)" so any unexpected unequal lengths after the round-trip fail fast; keep the loop body (asserting rec_pos[tid].logprob and rank) unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/serve/test_disagg_logprobs_serialization.py`:
- Around line 36-38: The test class TestDisaggLogprobsSerializationRoundtrip
currently has only two markers; add the required GPU and type markers by
annotating the class with `@pytest.mark.gpu_0` and `@pytest.mark.unit` (in addition
to the existing `@pytest.mark.pre_merge` and `@pytest.mark.trtllm`) so the class has
a scheduling marker, a GPU marker, and a type marker per guidelines.
- Around line 30-33: The test imports optional dependencies Logprob and
DisaggregatedParamsCodec directly; wrap those imports in a try/except
ImportError at module import time and call pytest.skip("missing optional
dependency: tensorrt_llm", allow_module_level=True) inside the except so the
test module is not collected when tensorrt_llm (and its Logprob) or disagg utils
are absent; update the import block referencing Logprob and
DisaggregatedParamsCodec accordingly to perform the guarded import and early
skip.
---
Nitpick comments:
In `@tests/serve/test_disagg_logprobs_serialization.py`:
- Around line 104-107: The zip over original and recovered should be strict to
catch length mismatches; update the loop that iterates with "for orig_pos,
rec_pos in zip(original, recovered):" to use "zip(original, recovered,
strict=True)" so any unexpected unequal lengths after the round-trip fail fast;
keep the loop body (asserting rec_pos[tid].logprob and rank) unchanged.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 8e1547d0-90eb-41c6-99b3-2453c325b748
📒 Files selected for processing (3)
components/src/dynamo/trtllm/request_handlers/handler_base.pycomponents/src/dynamo/trtllm/utils/disagg_utils.pytests/serve/test_disagg_logprobs_serialization.py
2ea183a to
6d3fae5
Compare
indrajit96
left a comment
There was a problem hiding this comment.
LGTM!
Some minor comments for test
6d3fae5 to
69ffb5c
Compare
TRT-LLM PR #11727 adds first_gen_log_probs to DisaggregatedParams to
carry the first generated token's logprobs from prefill to decode in
disaggregated serving. This field uses dicts with integer token-ID keys
(e.g. {4710: Logprob(...)}).
After dataclasses.asdict(), these int keys break pythonize 0.23's
depythonize which requires string map keys for serde_json::Value
(dict_key_not_string error). This cascades into a "Disconnected: Stream
ended before generation completed" error because the error response
can't be published before the stream context stops.
Add serialize/deserialize methods to DisaggregatedParamsCodec that
convert between TRT-LLM's internal {int: Logprob} format and a
JSON-safe list-of-dicts transport format, matching TRT-LLM's own
_serialize_first_gen_log_probs in openai_protocol.py.
Add unit tests for the serialization round-trip.
Fixes: DYN-2265
Signed-off-by: Yuewei Na <nv-yna@users.noreply.github.com>
aa905a3 to
afb12a4
Compare
Co-authored-by: Dmitry Tokarev <dtokarev@nvidia.com> Signed-off-by: Yuewei Na <248773860+nv-yna@users.noreply.github.com>
Co-authored-by: Dmitry Tokarev <dtokarev@nvidia.com> Signed-off-by: Yuewei Na <248773860+nv-yna@users.noreply.github.com>
…ai-dynamo#7145) Signed-off-by: Yuewei Na <nv-yna@users.noreply.github.com> Co-authored-by: Yuewei Na <nv-yna@users.noreply.github.com>
Summary
test_deployment[disaggregated_logprobs-2])first_gen_log_probstoDisaggregatedParamswith integer token-ID dict keys (e.g.{4710: Logprob(...)})dataclasses.asdict(), these int keys break Rustpythonize 0.23depythonize(dict_key_not_stringerror) → cascades intoDisconnected: Stream ended before generation completedserialize_first_gen_log_probs/deserialize_first_gen_log_probstoDisaggregatedParamsCodecusing TRT-LLM own list-of-dicts format ([{"token_id": id, "logprob": float, "rank": int}])Test plan
test_deployment[disaggregated_logprobs-2]passes with fix (all logprobs payloads validated — 300 tokens with logprobs each)test_deployment[disaggregated-2]passes (no regression on non-logprobs disagg)Fixes: DYN-2265 / NVBugs 5926823
Summary by CodeRabbit
Bug Fixes
Tests