Configurable message parsing #223
base: main
Changes from all commits: 33f898f, 572bba0, 1e7382d, 402fd89, f6020e6
```diff
@@ -1,12 +1,20 @@
 from __future__ import annotations
 
 import json
 import logging
+from abc import ABC, abstractmethod
 from functools import partial
-from typing import Any, ClassVar
+from typing import ClassVar
 
 import tree
-from aviary.core import MalformedMessageError, Message, Messages
-from aviary.tools import Tool, ToolCall, ToolRequestMessage
+from aviary.core import (
+    MalformedMessageError,
+    Message,
+    Messages,
+    ToolCall,
+    ToolRequestMessage,
+    Tools,
+)
 from transformers import LogitsProcessorList
 
+from ldp.graph.gradient_estimators import assign_constant_grads
```
@@ -23,71 +31,76 @@ | |
| ) | ||
| from ..lm_config import LMConfig # noqa: TID252 | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| class LocalLLMCallOp(Op[Message]): | ||
| """An Op that samples a token sequence from a local language model.""" | ||
|
|
||
| CTX_INPUTS_PREP_KEY: ClassVar[str] = "inputs_prepared" | ||
| CTX_TOOLS_PREP_KEY: ClassVar[str] = "tools_prepared" | ||
| CTX_OUTPUT_PREP_KEY: ClassVar[str] = "outputs_prepared" | ||
|
|
||
| model_name: str | ||
| class MessageAndToolParser(ABC): | ||
| """Base class to define how we translate between (messages, tools) and strings.""" | ||
|
|
||
| def __init__( | ||
| self, | ||
| model_config: LMConfig, | ||
| batch_size: int = 1, | ||
| max_wait_interval: float = 0.1, | ||
| parallel_mode_config: ParallelModeConfig | None = None, | ||
| ) -> None: | ||
| super().__init__() | ||
| supported_templates: ClassVar[set[str]] = set() | ||
|
|
||
| pad_token_id = model_config.get_tokenizer().pad_token_id | ||
| @classmethod | ||
| @abstractmethod | ||
| def get_message_content(cls, msg: Message) -> str | None: | ||
| """Represents a message as a string.""" | ||
|
|
||
| handler_config = TransformerHandlerConfig( | ||
| # configurable | ||
| lm_config=model_config, | ||
| batch_size=batch_size, | ||
| max_wait_interval=max_wait_interval, | ||
| parallel_mode_config=parallel_mode_config, | ||
| # constant configuration | ||
| lm_type=LMType.GENERATION, | ||
| module_call_fn=AsyncTransformerInterface.model_generate, | ||
| collate_fn=partial( | ||
| collate_fn_transformer_left_pad, pad_token_id=pad_token_id | ||
| ), | ||
| decollate_fn=decollate_fn_transformer_decoder, | ||
| ) | ||
| self.model_handler = handler_config.make_async_module() | ||
| self.model_name = model_config.model | ||
| @classmethod | ||
| @abstractmethod | ||
| def prep_tools_for_tokenizer(cls, tools: Tools | None) -> list[dict] | None: | ||
| """Prepares tools for tokenization.""" | ||
|
|
||
| self.llm_call_kwargs = {"logits_processor": LogitsProcessorList()} | ||
| @classmethod | ||
| @abstractmethod | ||
| def parse_tool_request_message( | ||
| cls, out_text: str, tools: Tools | ||
| ) -> ToolRequestMessage: | ||
| """Parses the output text from a tool request message.""" | ||
|
|
||
| @staticmethod | ||
| def prep_messages_for_tokenizer(xi: Messages) -> list[dict]: | ||
| @classmethod | ||
| def prep_messages_for_tokenizer(cls, msgs: Messages) -> list[dict]: | ||
| """Prepares message history for tokenization.""" | ||
| result: list[dict] = [] | ||
| for msg in xi: | ||
| content = msg.content | ||
| if isinstance(msg, ToolRequestMessage): | ||
| assert len(msg.tool_calls) == 1, ( | ||
| "Support parsing only single tool call for now" | ||
| ) | ||
| tool_call = msg.tool_calls[0] | ||
| # TODO: document where this format is coming from. Is this a Huggingface chat template syntax? | ||
| content_dict = { | ||
| "name": tool_call.function.name, | ||
| "parameters": tool_call.function.arguments, | ||
| "thought": msg.content, | ||
| } | ||
| content = json.dumps(content_dict) | ||
| assert content is not None, "content is None, doesn't make sense" | ||
|
|
||
| for msg in msgs: | ||
| content = cls.get_message_content(msg) | ||
| assert content is not None, f"Content should not be None: {msg!r}" | ||
| result.append({"role": msg.role, "content": content}) | ||
| return result | ||
|
|
||
| @staticmethod | ||
| def prep_tools_for_tokenizer(tools: list[Tool] | None) -> list[dict] | None: | ||
| """Prepare tools for the tokenizer by transforming them into a JSON schema.""" | ||
|
|
||
| class Llama31Parser(MessageAndToolParser): | ||
| """Follows the Llama 3.1 syntax. | ||
|
|
||
| See details: | ||
| https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#-tool-calling-(8b/70b/405b)- | ||
| """ | ||
|
|
||
| supported_templates: ClassVar[set[str]] = { | ||
| "llama2_chat_template_ori.jinja", | ||
|
Reviewer, on "llama2_chat_template_ori.jinja": No one but us is going to know what an Ori template is. Any chance you can rename this file, or add a comment explaining it?

Author: Yeah, it's on my to-do list to rename this, and there's documentation in …
```diff
+        "llama3.1_chat_template_hf.jinja",
+        "llama3.1_chat_template_nothought.jinja",
+        "llama3.1_chat_template_thought.jinja",
+        "llama3.1_chat_template_vllm.jinja",
+        "llama3_chat_template_ori.jinja",
+    }
+
+    @classmethod
+    def get_message_content(cls, msg: Message) -> str | None:
+        if isinstance(msg, ToolRequestMessage):
+            assert len(msg.tool_calls) == 1, (
+                "Support parsing only single tool call for now"
+            )
```
Reviewer, on lines +89 to +91: Since this is about supporting something, I would rather see …
```diff
+            tool_call = msg.tool_calls[0]
+            content_dict = {
+                "name": tool_call.function.name,
+                "parameters": tool_call.function.arguments,
+                "thought": msg.content,
+            }
+            return json.dumps(content_dict)
+
+        return msg.content
+
+    @classmethod
+    def prep_tools_for_tokenizer(cls, tools: Tools | None) -> list[dict] | None:
         if not tools:
             return None
```
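To make the Llama 3.1 encoding concrete, here is a small illustrative sketch (not part of the diff; the tool name and arguments are invented) of the JSON string that `Llama31Parser.get_message_content` returns for a single tool call:

```python
import json

# Invented example: the model wants to call get_weather(city="Paris")
# and produced free-text reasoning alongside the call.
content = json.dumps({
    "name": "get_weather",                       # tool_call.function.name
    "parameters": {"city": "Paris"},             # tool_call.function.arguments
    "thought": "I should look up the weather.",  # msg.content
})

# prep_messages_for_tokenizer then emits one dict per message, e.g.:
#   {"role": "assistant", "content": '{"name": "get_weather", ...}'}
print(content)
```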
@@ -112,13 +125,10 @@ def prep_tools_for_tokenizer(tools: list[Tool] | None) -> list[dict] | None: | |
| for tool in tools | ||
| ] | ||
|
|
||
| @staticmethod | ||
| def _parse_tool_request(out_text: str, tools: list[Tool]) -> ToolRequestMessage: | ||
| """Parse the output text to extract the tool request. | ||
|
|
||
| TODO: see if this needs to be configurable, e.g. for different model | ||
| output formats that we want to experiment with. | ||
| """ | ||
| @classmethod | ||
| def parse_tool_request_message( | ||
| cls, out_text: str, tools: Tools | ||
| ) -> ToolRequestMessage: | ||
| try: | ||
| tool_request = json.loads(out_text) | ||
| tool_name = tool_request["name"] | ||
|
|
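For the reverse direction, a minimal sketch of what the try block above does with model output (the payload is invented; only the `json.loads` call and key lookups mirror the diff):

```python
import json

# Invented model output in the Llama 3.1 tool-call format.
out_text = '{"name": "get_weather", "parameters": {"city": "Paris"}, "thought": "..."}'

tool_request = json.loads(out_text)  # a JSONDecodeError is re-raised as ValueError
tool_name = tool_request["name"]     # to be matched against the provided tools
print(tool_name, tool_request["parameters"])
```

The elided middle of the function (between this hunk and the next) builds a ToolRequestMessage from these fields; presumably an unknown tool name surfaces as the MalformedMessageError this module imports.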
@@ -136,16 +146,66 @@ def _parse_tool_request(out_text: str, tools: list[Tool]) -> ToolRequestMessage: | |
| except json.JSONDecodeError as err: | ||
| raise ValueError(f"Failed to parse tools call message: {out_text}") from err | ||
|
|
||
|
|
||
| class LocalLLMCallOp(Op[Message]): | ||
| """An Op that samples a token sequence from a local language model.""" | ||
|
|
||
| CTX_INPUTS_PREP_KEY: ClassVar[str] = "inputs_prepared" | ||
| CTX_TOOLS_PREP_KEY: ClassVar[str] = "tools_prepared" | ||
| CTX_OUTPUT_PREP_KEY: ClassVar[str] = "outputs_prepared" | ||
|
|
||
| model_name: str | ||
|
|
||
| def __init__( | ||
| self, | ||
| model_config: LMConfig, | ||
| batch_size: int = 1, | ||
| max_wait_interval: float = 0.1, | ||
| parallel_mode_config: ParallelModeConfig | None = None, | ||
| parser: type[MessageAndToolParser] = Llama31Parser, | ||
| ) -> None: | ||
| super().__init__() | ||
|
|
||
| pad_token_id = model_config.get_tokenizer().pad_token_id | ||
|
|
||
| handler_config = TransformerHandlerConfig( | ||
| # configurable | ||
| lm_config=model_config, | ||
| batch_size=batch_size, | ||
| max_wait_interval=max_wait_interval, | ||
| parallel_mode_config=parallel_mode_config, | ||
| # constant configuration | ||
| lm_type=LMType.GENERATION, | ||
| module_call_fn=AsyncTransformerInterface.model_generate, | ||
| collate_fn=partial( | ||
| collate_fn_transformer_left_pad, pad_token_id=pad_token_id | ||
| ), | ||
| decollate_fn=decollate_fn_transformer_decoder, | ||
| ) | ||
| self.model_handler = handler_config.make_async_module() | ||
| self.model_name = model_config.model | ||
|
|
||
| self.prep_messages_for_tokenizer = parser.prep_messages_for_tokenizer | ||
| self.prep_tools_for_tokenizer = parser.prep_tools_for_tokenizer | ||
| self.parse_tool_request_message = parser.parse_tool_request_message | ||
| if model_config.chat_template not in parser.supported_templates: | ||
| logger.warning( | ||
| f"Chat template {model_config.chat_template!r} not in " | ||
| f"{parser.__class__.__name__}.supported templates." | ||
| ) | ||
|
|
||
| self.llm_call_kwargs = {"logits_processor": LogitsProcessorList()} | ||
|
|
||
| async def forward( | ||
| self, | ||
| xi: list[Message], | ||
| msgs: list[Message], | ||
| temperature: float = 1.0, | ||
| max_new_tokens: int = 10, | ||
| tools: list[Tool] | None = None, | ||
| **kwargs: dict[str, Any], | ||
| tools: Tools | None = None, | ||
| **kwargs, | ||
| ) -> Message: | ||
| call_id = get_call_id() | ||
| inputs = self.prep_messages_for_tokenizer(xi) | ||
| inputs = self.prep_messages_for_tokenizer(msgs) | ||
| tools_json = self.prep_tools_for_tokenizer(tools) | ||
| if get_training_mode(): | ||
| self.ctx.update(call_id, LocalLLMCallOp.CTX_INPUTS_PREP_KEY, inputs) | ||
|
|
@@ -166,7 +226,7 @@ async def forward( | |
|
|
||
| out_msg = Message(role="assistant", content=out_text) | ||
| if tools and out_text.startswith("{"): | ||
| out_msg = self._parse_tool_request(out_text, tools) | ||
| out_msg = self.parse_tool_request_message(out_text, tools) | ||
|
|
||
| if get_training_mode(): | ||
| self.ctx.update( | ||
|
|
||
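Taken together, the refactor makes the parsing strategy injectable through the new `parser` argument. Below is a minimal downstream sketch, not code from this PR: `XmlToolParser`, its tag format, its template name, and the `lm_config` variable are all invented for illustration.

```python
import json
from typing import ClassVar

from aviary.core import Message, ToolRequestMessage, Tools


class XmlToolParser(MessageAndToolParser):
    """Hypothetical parser for a model that emits <tool>...</tool><args>...</args>."""

    supported_templates: ClassVar[set[str]] = {"xml_chat_template.jinja"}  # invented

    @classmethod
    def get_message_content(cls, msg: Message) -> str | None:
        if isinstance(msg, ToolRequestMessage):
            tool_call = msg.tool_calls[0]
            return (
                f"<tool>{tool_call.function.name}</tool>"
                f"<args>{json.dumps(tool_call.function.arguments)}</args>"
            )
        return msg.content

    @classmethod
    def prep_tools_for_tokenizer(cls, tools: Tools | None) -> list[dict] | None:
        raise NotImplementedError  # would mirror the JSON-schema transformation

    @classmethod
    def parse_tool_request_message(cls, out_text: str, tools: Tools) -> ToolRequestMessage:
        raise NotImplementedError  # inverse of get_message_content


# Swapping parsers requires no change to LocalLLMCallOp itself:
op = LocalLLMCallOp(lm_config, parser=XmlToolParser)
```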
Reviewer: If we are going to prepare tools, why are we then passing `None` tools? Just pointing out, `Tools | None` doesn't really make sense; imo this should be something like `*tools: Tool` or `tools: Tools`.

Author: It's because `tools` can be `None` here: ldp/ldp/graph/common_ops.py, line 256 in 0d0e0bd, in `[Local]LLMCallOp`.
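To spell out the alternatives under discussion (a sketch of the reviewer's proposal, not code from this PR; the function names are placeholders):

```python
from aviary.core import Tools
from aviary.tools import Tool

# As in this PR: None doubles as the "no tools" sentinel.
def prep_with_none(tools: Tools | None) -> list[dict] | None: ...

# Reviewer's alternatives: represent "no tools" as emptiness instead of None.
def prep_variadic(*tools: Tool) -> list[dict]: ...  # callers pass tools positionally
def prep_required(tools: Tools) -> list[dict]: ...  # callers pass an empty list
```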