nits
Ori Kabeli committed Mar 19, 2025
commit bddaa92f11cc69cd86ff974e474f73faa6967dde
src/ldp/alg/rollout.py — 8 changes: 6 additions & 2 deletions

@@ -121,7 +121,9 @@ async def sample_trajectories(self, **kwargs):
                 kwargs["environment_factory"],
                 kwargs.get("batch_size", 1),
                 kwargs.get("max_steps"),
-                log_exceptions_immediately=kwargs.get("log_exceptions_immediately", True)
+                log_exceptions_immediately=kwargs.get(
+                    "log_exceptions_immediately", True
+                ),
             )
 
         if "environments" in kwargs:
@@ -131,7 +133,9 @@ async def sample_trajectories(self, **kwargs):
             return await self._sample_trajectories_from_envs(
                 kwargs["environments"],
                 kwargs.get("max_steps"),
-                log_exceptions_immediately=kwargs.get("log_exceptions_immediately", True),
+                log_exceptions_immediately=kwargs.get(
+                    "log_exceptions_immediately", True
+                ),
             )
 
         raise TypeError(
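Note: the rollout.py hunks only reflow the log_exceptions_immediately keyword argument to fit the line-length limit; behavior is unchanged. For readers outside the codebase, here is a minimal runnable sketch of the call pattern being exercised — this RolloutManager class and the _sample_trajectories_from_envs body are simplified stand-ins for illustration, not the actual ldp implementation:

import asyncio


class RolloutManager:
    """Simplified stand-in for the rollout manager touched in this diff."""

    async def _sample_trajectories_from_envs(
        self, environments, max_steps, *, log_exceptions_immediately: bool = True
    ):
        # Placeholder body: a real rollout would step each environment.
        return [f"trajectory({env}, max_steps={max_steps})" for env in environments]

    async def sample_trajectories(self, **kwargs):
        if "environments" in kwargs:
            return await self._sample_trajectories_from_envs(
                kwargs["environments"],
                kwargs.get("max_steps"),
                # Same default threading as the diff: callers may omit the flag.
                log_exceptions_immediately=kwargs.get(
                    "log_exceptions_immediately", True
                ),
            )
        raise TypeError("sample_trajectories() expects 'environments' in kwargs")


trajectories = asyncio.run(
    RolloutManager().sample_trajectories(environments=["env-0", "env-1"], max_steps=5)
)
print(trajectories)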
src/ldp/nn/agent/simple_local_agent.py — 10 changes: 5 additions & 5 deletions

@@ -46,7 +46,7 @@ class AgentLMConfig(_LMConfig):
         ),
         validate_default=True,
     )
-    max_traj_token_count: int | None = Field(
+    max_messages_token_count: int | None = Field(
         default=None,
         description="If set, raise an error if the total tokens in the trajectory exceed this value.",
     )
@@ -128,7 +128,7 @@ async def get_asv(
 
     def _validate_token_count(self, messages: list[Message], tools: list[Tool]):
        """Asserts token count for the trajectory is within the limit."""
-        if self.llm_model.max_traj_token_count is None:
+        if self.llm_model.max_messages_token_count is None:
            return
        messages_for_tokenizer = self._llm_call_op.prep_messages_for_tokenizer(messages)
        tools_for_tokenizer = self._llm_call_op.prep_tools_for_tokenizer(tools)
@@ -138,12 +138,12 @@ def _validate_token_count(self, messages: list[Message], tools: list[Tool]):
             messages=messages_for_tokenizer,
             tools=tools_for_tokenizer,  # type: ignore[arg-type]
         )
-        if total_tokens > self.llm_model.max_traj_token_count:
+        if total_tokens > self.llm_model.max_messages_token_count:
            logger.error(
-                f"Token limit exceeded for trajectory: {total_tokens} > {self.llm_model.max_traj_token_count}"
+                f"Token limit exceeded for trajectory: {total_tokens} > {self.llm_model.max_messages_token_count}"
            )
            raise ValueError(
-                f"Token limit exceeded for trajectory: {total_tokens} > {self.llm_model.max_traj_token_count}"
+                f"Token limit exceeded for trajectory: {total_tokens} > {self.llm_model.max_messages_token_count}"
            )
 
     # TODO: maybe remove these recomputation methods. I added them to debug some things. But idk,
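Note: the simple_local_agent.py hunks rename max_traj_token_count to max_messages_token_count; the guard logic itself is untouched. A self-contained sketch of that guard, with a whitespace tokenizer standing in for the real one — apart from the renamed field, the names below are illustrative assumptions, not ldp internals:

import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)


@dataclass
class AgentLMConfig:
    # If set, raise an error if the total tokens in the messages exceed this value.
    max_messages_token_count: int | None = None


def count_tokens(messages: list[str]) -> int:
    # Crude stand-in for the real tokenizer: counts whitespace-separated words.
    return sum(len(message.split()) for message in messages)


def validate_token_count(config: AgentLMConfig, messages: list[str]) -> None:
    """Raise if the messages' total token count exceeds the configured limit."""
    if config.max_messages_token_count is None:
        return  # no limit configured, nothing to check
    total_tokens = count_tokens(messages)
    if total_tokens > config.max_messages_token_count:
        msg = (
            f"Token limit exceeded for trajectory: "
            f"{total_tokens} > {config.max_messages_token_count}"
        )
        logger.error(msg)
        raise ValueError(msg)


validate_token_count(AgentLMConfig(max_messages_token_count=4), ["one two", "three"])

One possible follow-up nit: the diff builds the same f-string twice, once for logger.error and once for ValueError; binding it to a variable once, as in the sketch, would avoid the duplication.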