diff --git a/ldp/nn/graph/llm_call_op.py b/ldp/nn/graph/llm_call_op.py
index 88b727ef..9a5426e5 100644
--- a/ldp/nn/graph/llm_call_op.py
+++ b/ldp/nn/graph/llm_call_op.py
@@ -1,12 +1,20 @@
 from __future__ import annotations
 
 import json
+import logging
+from abc import ABC, abstractmethod
 from functools import partial
-from typing import Any, ClassVar
+from typing import ClassVar
 
 import tree
-from aviary.core import MalformedMessageError, Message, Messages
-from aviary.tools import Tool, ToolCall, ToolRequestMessage
+from aviary.core import (
+    MalformedMessageError,
+    Message,
+    Messages,
+    ToolCall,
+    ToolRequestMessage,
+    Tools,
+)
 from transformers import LogitsProcessorList
 
 from ldp.graph.gradient_estimators import assign_constant_grads
@@ -23,71 +31,76 @@
 )
 from ..lm_config import LMConfig  # noqa: TID252
 
+logger = logging.getLogger(__name__)
+
 
-class LocalLLMCallOp(Op[Message]):
-    """An Op that samples a token sequence from a local language model."""
-
-    CTX_INPUTS_PREP_KEY: ClassVar[str] = "inputs_prepared"
-    CTX_TOOLS_PREP_KEY: ClassVar[str] = "tools_prepared"
-    CTX_OUTPUT_PREP_KEY: ClassVar[str] = "outputs_prepared"
 
-    model_name: str
+class MessageAndToolParser(ABC):
+    """Base class to define how we translate between (messages, tools) and strings."""
 
-    def __init__(
-        self,
-        model_config: LMConfig,
-        batch_size: int = 1,
-        max_wait_interval: float = 0.1,
-        parallel_mode_config: ParallelModeConfig | None = None,
-    ) -> None:
-        super().__init__()
+    supported_templates: ClassVar[set[str]] = set()
 
-        pad_token_id = model_config.get_tokenizer().pad_token_id
+    @classmethod
+    @abstractmethod
+    def get_message_content(cls, msg: Message) -> str | None:
+        """Represents a message as a string."""
 
-        handler_config = TransformerHandlerConfig(
-            # configurable
-            lm_config=model_config,
-            batch_size=batch_size,
-            max_wait_interval=max_wait_interval,
-            parallel_mode_config=parallel_mode_config,
-            # constant configuration
-            lm_type=LMType.GENERATION,
-            module_call_fn=AsyncTransformerInterface.model_generate,
-            collate_fn=partial(
-                collate_fn_transformer_left_pad, pad_token_id=pad_token_id
-            ),
-            decollate_fn=decollate_fn_transformer_decoder,
-        )
-        self.model_handler = handler_config.make_async_module()
-        self.model_name = model_config.model
+    @classmethod
+    @abstractmethod
+    def prep_tools_for_tokenizer(cls, tools: Tools | None) -> list[dict] | None:
+        """Prepares tools for tokenization."""
 
-        self.llm_call_kwargs = {"logits_processor": LogitsProcessorList()}
+    @classmethod
+    @abstractmethod
+    def parse_tool_request_message(
+        cls, out_text: str, tools: Tools
+    ) -> ToolRequestMessage:
+        """Parses the output text from a tool request message."""
 
-    @staticmethod
-    def prep_messages_for_tokenizer(xi: Messages) -> list[dict]:
+    @classmethod
+    def prep_messages_for_tokenizer(cls, msgs: Messages) -> list[dict]:
+        """Prepares message history for tokenization."""
         result: list[dict] = []
-        for msg in xi:
-            content = msg.content
-            if isinstance(msg, ToolRequestMessage):
-                assert len(msg.tool_calls) == 1, (
-                    "Support parsing only single tool call for now"
-                )
-                tool_call = msg.tool_calls[0]
-                # TODO: document where this format is coming from. Is this a Huggingface chat template syntax?
-                content_dict = {
-                    "name": tool_call.function.name,
-                    "parameters": tool_call.function.arguments,
-                    "thought": msg.content,
-                }
-                content = json.dumps(content_dict)
-            assert content is not None, "content is None, doesn't make sense"
-
+        for msg in msgs:
+            content = cls.get_message_content(msg)
+            assert content is not None, f"Content should not be None: {msg!r}"
             result.append({"role": msg.role, "content": content})
         return result
 
-    @staticmethod
-    def prep_tools_for_tokenizer(tools: list[Tool] | None) -> list[dict] | None:
-        """Prepare tools for the tokenizer by transforming them into a JSON schema."""
+
+class Llama31Parser(MessageAndToolParser):
+    """Follows the Llama 3.1 syntax.
+
+    See details:
+    https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#-tool-calling-(8b/70b/405b)-
+    """
+
+    supported_templates: ClassVar[set[str]] = {
+        "llama2_chat_template_ori.jinja",
+        "llama3.1_chat_template_hf.jinja",
+        "llama3.1_chat_template_nothought.jinja",
+        "llama3.1_chat_template_thought.jinja",
+        "llama3.1_chat_template_vllm.jinja",
+        "llama3_chat_template_ori.jinja",
+    }
+
+    @classmethod
+    def get_message_content(cls, msg: Message) -> str | None:
+        if isinstance(msg, ToolRequestMessage):
+            assert len(msg.tool_calls) == 1, (
+                "Support parsing only single tool call for now"
+            )
+            tool_call = msg.tool_calls[0]
+            content_dict = {
+                "name": tool_call.function.name,
+                "parameters": tool_call.function.arguments,
+                "thought": msg.content,
+            }
+            return json.dumps(content_dict)
+
+        return msg.content
+
+    @classmethod
+    def prep_tools_for_tokenizer(cls, tools: Tools | None) -> list[dict] | None:
         if not tools:
             return None
 
@@ -112,13 +125,10 @@ def prep_tools_for_tokenizer(tools: list[Tool] | None) -> list[dict] | None:
             for tool in tools
         ]
 
-    @staticmethod
-    def _parse_tool_request(out_text: str, tools: list[Tool]) -> ToolRequestMessage:
-        """Parse the output text to extract the tool request.
-
-        TODO: see if this needs to be configurable, e.g. for different model
-        output formats that we want to experiment with.
-        """
+    @classmethod
+    def parse_tool_request_message(
+        cls, out_text: str, tools: Tools
+    ) -> ToolRequestMessage:
         try:
             tool_request = json.loads(out_text)
             tool_name = tool_request["name"]
@@ -136,16 +146,66 @@ def _parse_tool_request(out_text: str, tools: list[Tool]) -> ToolRequestMessage:
         except json.JSONDecodeError as err:
             raise ValueError(f"Failed to parse tools call message: {out_text}") from err
 
+
+class LocalLLMCallOp(Op[Message]):
+    """An Op that samples a token sequence from a local language model."""
+
+    CTX_INPUTS_PREP_KEY: ClassVar[str] = "inputs_prepared"
+    CTX_TOOLS_PREP_KEY: ClassVar[str] = "tools_prepared"
+    CTX_OUTPUT_PREP_KEY: ClassVar[str] = "outputs_prepared"
+
+    model_name: str
+
+    def __init__(
+        self,
+        model_config: LMConfig,
+        batch_size: int = 1,
+        max_wait_interval: float = 0.1,
+        parallel_mode_config: ParallelModeConfig | None = None,
+        parser: type[MessageAndToolParser] = Llama31Parser,
+    ) -> None:
+        super().__init__()
+
+        pad_token_id = model_config.get_tokenizer().pad_token_id
+
+        handler_config = TransformerHandlerConfig(
+            # configurable
+            lm_config=model_config,
+            batch_size=batch_size,
+            max_wait_interval=max_wait_interval,
+            parallel_mode_config=parallel_mode_config,
+            # constant configuration
+            lm_type=LMType.GENERATION,
+            module_call_fn=AsyncTransformerInterface.model_generate,
+            collate_fn=partial(
+                collate_fn_transformer_left_pad, pad_token_id=pad_token_id
+            ),
+            decollate_fn=decollate_fn_transformer_decoder,
+        )
+        self.model_handler = handler_config.make_async_module()
+        self.model_name = model_config.model
+
+        self.prep_messages_for_tokenizer = parser.prep_messages_for_tokenizer
+        self.prep_tools_for_tokenizer = parser.prep_tools_for_tokenizer
+        self.parse_tool_request_message = parser.parse_tool_request_message
+        if model_config.chat_template not in parser.supported_templates:
+            logger.warning(
+                f"Chat template {model_config.chat_template!r} not in "
+                f"{parser.__name__}.supported_templates."
+            )
+
+        self.llm_call_kwargs = {"logits_processor": LogitsProcessorList()}
+
     async def forward(
         self,
-        xi: list[Message],
+        msgs: list[Message],
         temperature: float = 1.0,
         max_new_tokens: int = 10,
-        tools: list[Tool] | None = None,
-        **kwargs: dict[str, Any],
+        tools: Tools | None = None,
+        **kwargs,
     ) -> Message:
         call_id = get_call_id()
-        inputs = self.prep_messages_for_tokenizer(xi)
+        inputs = self.prep_messages_for_tokenizer(msgs)
         tools_json = self.prep_tools_for_tokenizer(tools)
         if get_training_mode():
             self.ctx.update(call_id, LocalLLMCallOp.CTX_INPUTS_PREP_KEY, inputs)
@@ -166,7 +226,7 @@ async def forward(
         out_msg = Message(role="assistant", content=out_text)
 
         if tools and out_text.startswith("{"):
-            out_msg = self._parse_tool_request(out_text, tools)
+            out_msg = self.parse_tool_request_message(out_text, tools)
 
         if get_training_mode():
             self.ctx.update(
diff --git a/ldp/nn/handlers/transformer_handler.py b/ldp/nn/handlers/transformer_handler.py
index 6fd7b740..8561b355 100644
--- a/ldp/nn/handlers/transformer_handler.py
+++ b/ldp/nn/handlers/transformer_handler.py
@@ -66,15 +66,15 @@ TParams = ParamSpec("TParams")
 
 
-def is_conversation(messages) -> bool:
-    """Check if messages is an instance of Conversation."""
-    return isinstance(messages, list) and all(
+def is_message_history(maybe_messages) -> bool:
+    """Check if input is a message history encoded as list of dict[str, str]."""
+    return isinstance(maybe_messages, list) and all(
         isinstance(msg, dict)
         and all(
             isinstance(key, str) and isinstance(value, str)
             for key, value in msg.items()
         )
-        for msg in messages
+        for msg in maybe_messages
     )
 
 
@@ -894,7 +894,7 @@ def _get_tokenized_inputs(
         return BatchEncoding(inputs)
     if isinstance(inputs, str):
         return tokenizer(inputs, return_tensors="pt")
-    if is_conversation(inputs):
+    if is_message_history(inputs):
         return tokenizer.apply_chat_template(
             inputs,
             tools=tools_json,