Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
[chore] Migrate print statements to logging
  • Loading branch information
fjosw committed Feb 5, 2026
commit 34ae9879b6d7d6088b927d01306740929205cc29
2 changes: 1 addition & 1 deletion verl/checkpoint_engine/hccl_checkpoint_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from verl.utils.net_utils import get_free_port, is_valid_ipv6_address

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO"))


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion verl/checkpoint_engine/nccl_checkpoint_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from verl.utils.net_utils import get_free_port, is_valid_ipv6_address

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO"))


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion verl/checkpoint_engine/nixl_checkpoint_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from verl.utils.net_utils import get_free_port, is_valid_ipv6_address

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO"))


@dataclass
Expand Down
6 changes: 3 additions & 3 deletions verl/experimental/agent_loop/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@
from verl.utils.transferqueue_utils import tqbridge
from verl.workers.rollout.replica import TokenOutput, get_rollout_replica_class

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO"))


class AsyncLLMServerManager:
Expand Down Expand Up @@ -901,7 +901,7 @@ def _initialize_llm_servers(self, rollout_resource_pool: RayResourcePool):
self.server_handles = [server._server_handle for server in self.rollout_replicas]
self.server_addresses = [server._server_address for server in self.rollout_replicas]

print(f"AgentLoopManager: {self.server_addresses}")
logger.info(f"AgentLoopManager server addresses: {self.server_addresses}")

# Update Prometheus configuration with server addresses
if rollout_config.prometheus.enable:
Expand Down
16 changes: 9 additions & 7 deletions verl/experimental/agent_loop/prometheus_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

from verl.workers.config.rollout import PrometheusConfig

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO"))


def update_prometheus_config(config: PrometheusConfig, server_addresses: list[str], rollout_name: str | None = None):
Expand Down Expand Up @@ -71,11 +71,11 @@ def reload_prometheus(port):
reload_url = f"http://{ip_address}:{port}/-/reload"

try:
subprocess.run(["curl", "-X", "POST", reload_url], capture_output=True, text=True, timeout=10)
print(f"Reloading Prometheus on node: {reload_url}")
except Exception:
result = subprocess.run(["curl", "-X", "POST", reload_url], capture_output=True, text=True, timeout=10)
logger.debug(f"Prometheus reload request sent to {reload_url}, return code: {result.returncode}")
except Exception as e:
# Skip errors on non-master nodes
pass
logger.debug(f"Failed to reload Prometheus at {reload_url}: {e}")

# Get all available nodes and schedule tasks on each node
nodes = ray.nodes()
Expand All @@ -93,7 +93,9 @@ def reload_prometheus(port):
ray.get(write_tasks)

server_type = rollout_name.upper() if rollout_name else "rollout"
print(f"Updated Prometheus configuration at {config.file} with {len(server_addresses)} {server_type} servers")
logger.info(
f"Updated Prometheus configuration at {config.file} with {len(server_addresses)} {server_type} servers"
)

# Reload Prometheus on all nodes
reload_tasks = []
Expand Down
4 changes: 2 additions & 2 deletions verl/experimental/agent_loop/single_turn_agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from verl.tools.utils.tool_registry import initialize_tools_from_config
from verl.utils.profiler import simple_timer

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO"))


@register("single_turn_agent")
Expand Down
4 changes: 2 additions & 2 deletions verl/experimental/agent_loop/tool_agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@
from verl.utils.profiler import simple_timer
from verl.utils.rollout_trace import rollout_trace_op

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO"))


class AgentState(Enum):
Expand Down
4 changes: 2 additions & 2 deletions verl/experimental/agent_loop/tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from verl.utils.ray_utils import get_event_loop
from verl.utils.rollout_trace import rollout_trace_op

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO"))


class FunctionCall(BaseModel):
Expand Down
22 changes: 11 additions & 11 deletions verl/experimental/fully_async_policy/agent_loop/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
rollout_trace_op,
)

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO"))


class FullyAsyncLLMServerManager(AsyncLLMServerManager):
Expand Down Expand Up @@ -232,12 +232,12 @@ def __init__(
from verl.experimental.fully_async_policy.sglang_rollout.sglang_async_server import FullyAsyncSGLangReplica

self.rollout_replica_class = FullyAsyncSGLangReplica
print("[FullyAsyncAgentLoopManager] SGLang replica class selected")
logger.info("SGLang replica class selected")
elif rollout_name == "vllm":
from verl.experimental.fully_async_policy.vllm_rollout.vllm_async_server import FullyAsyncvLLMReplica

self.rollout_replica_class = FullyAsyncvLLMReplica
print("[FullyAsyncAgentLoopManager] vLLM replica class selected")
logger.info("vLLM replica class selected")
else:
raise ValueError(f"Unsupported rollout name: {rollout_name}. Supported values are 'sglang' and 'vllm'.")

Expand Down Expand Up @@ -294,7 +294,7 @@ async def _initialize_llm_servers_async(self):
self.server_handles = [server._server_handle for server in self.rollout_replicas]
self.server_addresses = [server._server_address for server in self.rollout_replicas]

print(f"AgentLoopManager: {self.server_addresses}")
logger.info(f"AgentLoopManager server addresses: {self.server_addresses}")
# Update Prometheus configuration with server addresses
if rollout_config.prometheus.enable:
if rollout_config.disable_log_stats:
Expand Down Expand Up @@ -348,26 +348,26 @@ async def sleep(self):
await asyncio.gather(*[replica.sleep() for replica in self.rollout_replicas])

async def reset_prefix_cache(self):
print("[FullyAsyncAgentLoopManager] Reset prefix cache ...")
logger.info("Reset prefix cache ...")
# await asyncio.gather(*[replica.reset_prefix_cache() for replica in self.rollout_replicas])
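# NOTE(debug): each replica's reset is wrapped in its own timeout and error handling (instead of the bare gather above) so a hung or failing replica can be identified by index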
timeout = 5.0

async def reset_one(idx, replica):
print(f"[reset_prefix_cache] start replica={idx}")
logger.debug(f"reset_prefix_cache start replica={idx}")
try:
await asyncio.wait_for(replica.reset_prefix_cache(), timeout=timeout)
except asyncio.TimeoutError:
print(f"[reset_prefix_cache] TIMEOUT replica={idx} after {timeout}s")
logger.warning(f"reset_prefix_cache TIMEOUT replica={idx} after {timeout}s")
return
except Exception as e:
print(f"[reset_prefix_cache] ERROR replica={idx}: {e!r}")
logger.error(f"reset_prefix_cache ERROR replica={idx}: {e!r}")
return
print(f"[reset_prefix_cache] done replica={idx}")
logger.debug(f"reset_prefix_cache done replica={idx}")

tasks = [reset_one(i, replica) for i, replica in enumerate(self.rollout_replicas)]
await asyncio.gather(*tasks, return_exceptions=True)
print("[FullyAsyncAgentLoopManager] Reset prefix cache finished")
logger.info("Reset prefix cache finished")

async def clear_kv_cache(self):
await asyncio.gather(*[replica.clear_kv_cache() for replica in self.rollout_replicas])
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from verl.experimental.agent_loop.agent_loop import AgentLoopOutput, register
from verl.utils.profiler import simple_timer

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO"))


@register("partial_single_turn_agent")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from verl.experimental.agent_loop.tool_agent_loop import AgentData, AgentState, ToolAgentLoop
from verl.utils.profiler import simple_timer

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO"))


@register("async_partial_tool_agent")
Expand Down
6 changes: 3 additions & 3 deletions verl/experimental/fully_async_policy/base_detach_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
from verl.utils.device import get_torch_device, is_npu_available
from verl.utils.distributed import stateless_init_process_group

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO"))


class BaseDetachNcclSync:
Expand Down Expand Up @@ -60,7 +60,7 @@ def get_last_avg_bucket_size_remote(self):
def record_sync_metrics(cls, bucket_size_mb, sync_time):
"""Dynamically adjust the bucket size based on past synchronization times."""
bucket_size_mb_value = bucket_size_mb[0] if isinstance(bucket_size_mb, list) else bucket_size_mb
print(f"[DetachNcclSync] sync_metrics: bucket_size_mb={bucket_size_mb_value:.2f}MB, sync_time={sync_time:.2f}s")
logger.info(f"sync_metrics: bucket_size_mb={bucket_size_mb_value:.2f}MB, sync_time={sync_time:.2f}s")
cls._sync_history.append((bucket_size_mb_value, sync_time))
if len(cls._sync_history) > cls._max_history_size:
cls._sync_history.pop(0)
Expand Down
21 changes: 12 additions & 9 deletions verl/experimental/fully_async_policy/checkpoint_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""

import concurrent.futures
import logging
import os
import re
import socket
Expand All @@ -27,6 +28,9 @@
from typing import TYPE_CHECKING, Annotated, Any, TypedDict

import torch

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO"))
import zmq
from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
from ray.util.collective import collective
Expand Down Expand Up @@ -162,7 +166,7 @@ def get_ip() -> str:
return s.getsockname()[0]
except Exception as e: # noqa: BLE001
# fall back to resolving the IP from the hostname
print(f"fail to get ip from network interface, fallback to get ip from hostname: {e}")
logger.warning(f"fail to get ip from network interface, fallback to get ip from hostname: {e}")
return socket.gethostbyname(socket.gethostname())


Expand Down Expand Up @@ -309,7 +313,7 @@ def register_checkpoint(
bucket_size = max(
self.device_buffer_size_M << 20, max(_align_size(dtype, shape) for _, shape, dtype in weights_info)
)
print(
logger.info(
f"set checkpoint_engine device buffer size: {self.device_buffer_size_M}M, "
f"and finally set it to {bucket_size >> 20}M considering the largest parameter tensor size"
)
Expand Down Expand Up @@ -365,10 +369,10 @@ def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
f"buffer numel {buffer.numel()} should be equal to bucket size {local_buckets[idx].size}"
)
memory_buffers[idx].buffer = buffer
print(
f"[rank{self.current_rank}] register pin_memory for "
f" bucket {idx + 1}/{len(local_buckets)} finished, "
f"size {buffer.numel() / 1024 / 1024:.2f}MiB, start to copy tensors to buffer"
logger.info(
f"[rank{self.current_rank}] register pin_memory for bucket "
f"{idx + 1}/{len(local_buckets)} finished, size "
f"{buffer.numel() / 1024 / 1024:.2f}MiB, start to copy tensors to buffer"
)
offset = 0
for meta in local_buckets[idx].metas:
Expand Down Expand Up @@ -436,10 +440,9 @@ def update_checkpoint(self, inference_model, group_name: str, overlap_broadcast_
device=get_torch_device().current_device(),
)
except Exception:
print(
logger.error(
"allocate buffer for update_checkpoint failed, "
"you may need to reduce "
"config.async_training.checkpoint_engine.device_buffer_size_M"
"you may need to reduce config.async_training.checkpoint_engine.device_buffer_size_M"
)
raise

Expand Down
11 changes: 8 additions & 3 deletions verl/experimental/fully_async_policy/detach_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import time
from collections import defaultdict
from dataclasses import dataclass
Expand All @@ -19,6 +21,9 @@
import numpy as np
import torch

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO"))

from verl import DataProto
from verl.experimental.agent_loop.agent_loop import AgentLoopOutput
from verl.trainer.ppo.ray_trainer import compute_response_mask
Expand Down Expand Up @@ -119,7 +124,7 @@ def assemble_batch_from_rollout_samples(
if not rollout_samples:
raise ValueError("Empty rollout_samples provided for batch assembly")

print(f"[BatchUtils] Assembling batch from {len(rollout_samples)} RolloutSample objects")
logger.info(f"Assembling batch from {len(rollout_samples)} RolloutSample objects")

rollout_samples_batch = []
processing_times = []
Expand Down Expand Up @@ -189,7 +194,7 @@ def assemble_batch_from_rollout_samples(
}
)

print(f"[BatchUtils] Batch assembly completed in {time.time() - start_time:.2f}s")
logger.info(f"Batch assembly completed in {time.time() - start_time:.2f}s")

return final_batch

Expand Down Expand Up @@ -322,7 +327,7 @@ def get_aggregated_metrics(self) -> dict[str, Any]:
# Aggregate special metrics
aggregated = self._special_metrics_aggergate(aggregated)

print(f"aggregated metrics done. cost {time.time() - t}")
logger.info(f"aggregated metrics done. cost {time.time() - t:.2f}")

return aggregated

Expand Down
Loading