Alternative implementation of thinking mode #1723
Conversation
```python
def get_json_schema_logits_processor(
    backend_name: str | None,
    model: SteerableModel,
```
I'd prefer a separate function that calls `get_json_schema_logits_processor` instead of the current branching logic.
Also, is there a clean way to get the `JsonSchema`, `Regex` and `CFG` objects up to this point? That would allow us to have a single function `get_thinking_logits_processor` that dispatches depending on the type.
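A minimal sketch of the single-dispatch idea suggested above. The constraint classes and the per-type factory stubs here are illustrative stand-ins, not the library's actual API; the real factories would be the ones shown in this diff.

```python
from dataclasses import dataclass


@dataclass
class JsonSchema:
    schema: str


@dataclass
class Regex:
    pattern: str


@dataclass
class CFG:
    definition: str


# Stub per-type factories standing in for the real ones in the diff above.
def get_json_schema_logits_processor(backend_name, model, schema):
    return ("json_schema", schema)


def get_regex_logits_processor(backend_name, model, pattern):
    return ("regex", pattern)


def get_cfg_logits_processor(backend_name, model, definition):
    return ("cfg", definition)


def get_thinking_logits_processor(constraint, backend_name=None, model=None):
    """Single entry point that dispatches on the constraint's type."""
    if isinstance(constraint, JsonSchema):
        return get_json_schema_logits_processor(backend_name, model, constraint.schema)
    if isinstance(constraint, Regex):
        return get_regex_logits_processor(backend_name, model, constraint.pattern)
    if isinstance(constraint, CFG):
        return get_cfg_logits_processor(backend_name, model, constraint.definition)
    raise TypeError(f"unsupported constraint type: {type(constraint).__name__}")
```

With typed constraint objects, the dispatch is a flat `isinstance` chain rather than string comparisons scattered through the factories.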
```python
    return backend_logits_processor


def get_regex_logits_processor(
```
```python
    return backend_logits_processor


def get_cfg_logits_processor(
```
```diff
 def _bias_logits_mlx(  # pragma: no cover
-    self, batch_size: int, logits: TensorType
+    self, batch_size: int, logits: TensorType, skip: list[bool]
```
If we go with this design, I would consider a different name like `passthrough`.
```python
if all(self._is_thinking):
    return logits

...

return self.logits_processor.process_logits(input_ids, logits)
```
I'm wondering if we could transform all this into operations on arrays so we don't have to call `process_logits` for the sequences where the end-of-think token has not been generated. It would go as:
- Extract the sequences where end-of-think is present.
- Run `process_logits` on them.
- Rebuild the logits array with all sequences.

What do you think?
That would be the best option, although it means the downstream logits processor needs to be able to handle tensors of different batch sizes, and not always in the same order. I'm going to look into how constraining that is.
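A NumPy sketch of the gather/process/scatter approach discussed above, to make the shape concrete. The helper name and the `process_logits(input_ids, logits)` signature are assumptions for illustration; the real code would operate on the backend's tensors rather than NumPy arrays.

```python
import numpy as np


def process_logits_selectively(input_ids, logits, is_thinking, process_logits):
    """Run `process_logits` only on sequences past their end-of-think token.

    Rows that are still thinking pass through untouched; the remaining rows
    are gathered into a smaller batch, processed, and scattered back into
    their original positions.
    """
    is_thinking = np.asarray(is_thinking)
    if is_thinking.all():
        # Every sequence is still thinking: nothing to constrain yet.
        return logits
    done = ~is_thinking
    out = logits.copy()
    # Gather the finished-thinking rows into a contiguous sub-batch,
    # process them, then scatter the results back with the boolean mask.
    done_ids = [ids for ids, d in zip(input_ids, done) if d]
    out[done] = process_logits(done_ids, logits[done])
    return out
```

Note that boolean-mask gather/scatter preserves the relative order of the selected rows, so the downstream processor only has to cope with a varying batch size, not with reordering.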