198 changes: 129 additions & 69 deletions ldp/nn/graph/llm_call_op.py
@@ -1,12 +1,20 @@
from __future__ import annotations

import json
import logging
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, ClassVar
from typing import ClassVar

import tree
from aviary.core import MalformedMessageError, Message, Messages
from aviary.tools import Tool, ToolCall, ToolRequestMessage
from aviary.core import (
MalformedMessageError,
Message,
Messages,
ToolCall,
ToolRequestMessage,
Tools,
)
from transformers import LogitsProcessorList

from ldp.graph.gradient_estimators import assign_constant_grads
@@ -23,71 +31,76 @@
)
from ..lm_config import LMConfig # noqa: TID252

logger = logging.getLogger(__name__)

class LocalLLMCallOp(Op[Message]):
"""An Op that samples a token sequence from a local language model."""

CTX_INPUTS_PREP_KEY: ClassVar[str] = "inputs_prepared"
CTX_TOOLS_PREP_KEY: ClassVar[str] = "tools_prepared"
CTX_OUTPUT_PREP_KEY: ClassVar[str] = "outputs_prepared"

model_name: str
class MessageAndToolParser(ABC):
"""Base class to define how we translate between (messages, tools) and strings."""

def __init__(
self,
model_config: LMConfig,
batch_size: int = 1,
max_wait_interval: float = 0.1,
parallel_mode_config: ParallelModeConfig | None = None,
) -> None:
super().__init__()
supported_templates: ClassVar[set[str]] = set()

pad_token_id = model_config.get_tokenizer().pad_token_id
@classmethod
@abstractmethod
def get_message_content(cls, msg: Message) -> str | None:
"""Represents a message as a string."""

handler_config = TransformerHandlerConfig(
# configurable
lm_config=model_config,
batch_size=batch_size,
max_wait_interval=max_wait_interval,
parallel_mode_config=parallel_mode_config,
# constant configuration
lm_type=LMType.GENERATION,
module_call_fn=AsyncTransformerInterface.model_generate,
collate_fn=partial(
collate_fn_transformer_left_pad, pad_token_id=pad_token_id
),
decollate_fn=decollate_fn_transformer_decoder,
)
self.model_handler = handler_config.make_async_module()
self.model_name = model_config.model
@classmethod
@abstractmethod
def prep_tools_for_tokenizer(cls, tools: Tools | None) -> list[dict] | None:
[Review comment (Collaborator)]
If we are going to prepare tools, why are we then passing None tools?
Just pointing out, Tools | None doesn't really make sense; imo this should be something like *tools: Tool or tools: Tools.

[Author reply (Collaborator)]
It's because tools can be None here:
    tools: list[Tool] | None = None,
- this method just mirrors the typing of the arguments that can be passed to [Local]LLMCallOp.

"""Prepares tools for tokenization."""

self.llm_call_kwargs = {"logits_processor": LogitsProcessorList()}
@classmethod
@abstractmethod
def parse_tool_request_message(
cls, out_text: str, tools: Tools
) -> ToolRequestMessage:
"""Parses the output text from a tool request message."""

@staticmethod
def prep_messages_for_tokenizer(xi: Messages) -> list[dict]:
@classmethod
def prep_messages_for_tokenizer(cls, msgs: Messages) -> list[dict]:
"""Prepares message history for tokenization."""
result: list[dict] = []
for msg in xi:
content = msg.content
if isinstance(msg, ToolRequestMessage):
assert len(msg.tool_calls) == 1, (
"Support parsing only single tool call for now"
)
tool_call = msg.tool_calls[0]
# TODO: document where this format is coming from. Is this a Huggingface chat template syntax?
content_dict = {
"name": tool_call.function.name,
"parameters": tool_call.function.arguments,
"thought": msg.content,
}
content = json.dumps(content_dict)
assert content is not None, "content is None, doesn't make sense"

for msg in msgs:
content = cls.get_message_content(msg)
assert content is not None, f"Content should not be None: {msg!r}"
result.append({"role": msg.role, "content": content})
return result

@staticmethod
def prep_tools_for_tokenizer(tools: list[Tool] | None) -> list[dict] | None:
"""Prepare tools for the tokenizer by transforming them into a JSON schema."""

class Llama31Parser(MessageAndToolParser):
"""Follows the Llama 3.1 syntax.

See details:
https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#-tool-calling-(8b/70b/405b)-
"""

supported_templates: ClassVar[set[str]] = {
"llama2_chat_template_ori.jinja",
[Review comment (Collaborator)]
No one but us is going to know what an Ori template is. Any chance you can rename this file, or add a comment explaining it?

[Author reply (Collaborator)]
Yeah, it's on my to-do list to rename this, and there's documentation in chat_templates/README.md for all of the chat templates. I'll leave it for a later PR to sort this out.

"llama3.1_chat_template_hf.jinja",
"llama3.1_chat_template_nothought.jinja",
"llama3.1_chat_template_thought.jinja",
"llama3.1_chat_template_vllm.jinja",
"llama3_chat_template_ori.jinja",
}

@classmethod
def get_message_content(cls, msg: Message) -> str | None:
if isinstance(msg, ToolRequestMessage):
assert len(msg.tool_calls) == 1, (
"Support parsing only single tool call for now"
)
[Review comment on lines +89 to +91 (Collaborator)]
Since this is about supporting something, I would rather see raise NotImplementedError than assert.

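A minimal sketch of the suggested alternative (hypothetical, not applied in this diff), keeping the same single-tool-call restriction:

if len(msg.tool_calls) != 1:
    # Signal the unsupported multi-call case explicitly; unlike assert,
    # this check is not stripped when Python runs with -O.
    raise NotImplementedError(
        f"Only a single tool call is supported for now, got {len(msg.tool_calls)}."
    )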
tool_call = msg.tool_calls[0]
content_dict = {
"name": tool_call.function.name,
"parameters": tool_call.function.arguments,
"thought": msg.content,
}
return json.dumps(content_dict)

return msg.content

@classmethod
def prep_tools_for_tokenizer(cls, tools: Tools | None) -> list[dict] | None:
if not tools:
return None

@@ -112,13 +125,10 @@ def prep_tools_for_tokenizer(tools: list[Tool] | None) -> list[dict] | None:
for tool in tools
]

@staticmethod
def _parse_tool_request(out_text: str, tools: list[Tool]) -> ToolRequestMessage:
"""Parse the output text to extract the tool request.

TODO: see if this needs to be configurable, e.g. for different model
output formats that we want to experiment with.
"""
@classmethod
def parse_tool_request_message(
cls, out_text: str, tools: Tools
) -> ToolRequestMessage:
try:
tool_request = json.loads(out_text)
tool_name = tool_request["name"]
@@ -136,16 +146,66 @@ def _parse_tool_request(out_text: str, tools: list[Tool]) -> ToolRequestMessage:
except json.JSONDecodeError as err:
raise ValueError(f"Failed to parse tools call message: {out_text}") from err
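To make the format concrete, here is a minimal sketch of the JSON payload that Llama31Parser.get_message_content emits for a single tool call (the tool name and arguments below are illustrative assumptions, not taken from this diff):

import json

# Keys mirror content_dict above:
#   name       <- tool_call.function.name
#   parameters <- tool_call.function.arguments
#   thought    <- msg.content
payload = {
    "name": "search",
    "parameters": {"query": "llama 3.1 tool calling"},
    "thought": "I should look this up before answering.",
}
out_text = json.dumps(payload)
# forward() only attempts tool parsing when the generated text starts with "{",
# and parse_tool_request_message reads the "name" field back out of this JSON.
assert out_text.startswith("{")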


class LocalLLMCallOp(Op[Message]):
"""An Op that samples a token sequence from a local language model."""

CTX_INPUTS_PREP_KEY: ClassVar[str] = "inputs_prepared"
CTX_TOOLS_PREP_KEY: ClassVar[str] = "tools_prepared"
CTX_OUTPUT_PREP_KEY: ClassVar[str] = "outputs_prepared"

model_name: str

def __init__(
self,
model_config: LMConfig,
batch_size: int = 1,
max_wait_interval: float = 0.1,
parallel_mode_config: ParallelModeConfig | None = None,
parser: type[MessageAndToolParser] = Llama31Parser,
) -> None:
super().__init__()

pad_token_id = model_config.get_tokenizer().pad_token_id

handler_config = TransformerHandlerConfig(
# configurable
lm_config=model_config,
batch_size=batch_size,
max_wait_interval=max_wait_interval,
parallel_mode_config=parallel_mode_config,
# constant configuration
lm_type=LMType.GENERATION,
module_call_fn=AsyncTransformerInterface.model_generate,
collate_fn=partial(
collate_fn_transformer_left_pad, pad_token_id=pad_token_id
),
decollate_fn=decollate_fn_transformer_decoder,
)
self.model_handler = handler_config.make_async_module()
self.model_name = model_config.model

self.prep_messages_for_tokenizer = parser.prep_messages_for_tokenizer
self.prep_tools_for_tokenizer = parser.prep_tools_for_tokenizer
self.parse_tool_request_message = parser.parse_tool_request_message
if model_config.chat_template not in parser.supported_templates:
logger.warning(
f"Chat template {model_config.chat_template!r} not in "
f"{parser.__class__.__name__}.supported templates."
)

self.llm_call_kwargs = {"logits_processor": LogitsProcessorList()}

async def forward(
self,
xi: list[Message],
msgs: list[Message],
temperature: float = 1.0,
max_new_tokens: int = 10,
tools: list[Tool] | None = None,
**kwargs: dict[str, Any],
tools: Tools | None = None,
**kwargs,
) -> Message:
call_id = get_call_id()
inputs = self.prep_messages_for_tokenizer(xi)
inputs = self.prep_messages_for_tokenizer(msgs)
tools_json = self.prep_tools_for_tokenizer(tools)
if get_training_mode():
self.ctx.update(call_id, LocalLLMCallOp.CTX_INPUTS_PREP_KEY, inputs)
@@ -166,7 +226,7 @@ async def forward(

out_msg = Message(role="assistant", content=out_text)
if tools and out_text.startswith("{"):
out_msg = self._parse_tool_request(out_text, tools)
out_msg = self.parse_tool_request_message(out_text, tools)

if get_training_mode():
self.ctx.update(
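A usage sketch for the new parser hook (hypothetical names; the LMConfig instance and the template file below are assumptions, not from this diff):

# Subclass the default parser to declare support for a custom chat template.
class MyLlamaParser(Llama31Parser):
    supported_templates = {"my_custom_template.jinja"}

# my_lm_config: a pre-built LMConfig for a local Llama 3.1 checkpoint (assumed).
# LocalLLMCallOp then routes message/tool preparation and tool-call parsing
# through the supplied parser class instead of the default Llama31Parser.
op = LocalLLMCallOp(model_config=my_lm_config, parser=MyLlamaParser)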
10 changes: 5 additions & 5 deletions ldp/nn/handlers/transformer_handler.py
@@ -66,15 +66,15 @@
TParams = ParamSpec("TParams")


def is_conversation(messages) -> bool:
"""Check if messages is an instance of Conversation."""
return isinstance(messages, list) and all(
def is_message_history(maybe_messages) -> bool:
"""Check if input is a message history encoded as list of dict[str, str]."""
return isinstance(maybe_messages, list) and all(
isinstance(msg, dict)
and all(
isinstance(key, str) and isinstance(value, str)
for key, value in msg.items()
)
for msg in messages
for msg in maybe_messages
)
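A quick sketch of what the renamed check accepts and rejects (illustrative inputs):

# Accepted: a list of dicts whose keys and values are all strings.
assert is_message_history([{"role": "user", "content": "hi"}])
# Rejected: not a list, or a message dict containing a non-string value.
assert not is_message_history("hi")
assert not is_message_history([{"role": "user", "content": 42}])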


@@ -894,7 +894,7 @@ def _get_tokenized_inputs(
return BatchEncoding(inputs)
if isinstance(inputs, str):
return tokenizer(inputs, return_tensors="pt")
if is_conversation(inputs):
if is_message_history(inputs):
return tokenizer.apply_chat_template(
inputs,
tools=tools_json,
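For context, a generic sketch of how a tokenizer's apply_chat_template is typically invoked with tool definitions in recent transformers releases (the model name and keyword values are assumptions; this is not the elided argument list from the diff):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
encoded = tokenizer.apply_chat_template(
    [{"role": "user", "content": "What is 2 + 2?"}],
    tools=None,                  # or a list of JSON-schema tool definitions
    add_generation_prompt=True,  # append the assistant header before generation
    return_tensors="pt",
    return_dict=True,            # returns a BatchEncoding with input_ids / attention_mask
)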