Skip to content
292 changes: 172 additions & 120 deletions python/paddle/distributed/flex_checkpoint/aoa/aoa_engine.py

Large diffs are not rendered by default.

35 changes: 31 additions & 4 deletions python/paddle/distributed/flex_checkpoint/aoa/lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class Lexer:
('MISMATCH', r'.'),
]

def __init__(self, context):
def __init__(self, context, traceback=None):
from .macros import macro_registry

self.macros = [list(d.values())[1] for d in macro_registry.macros]
Expand All @@ -66,6 +66,7 @@ def __init__(self, context):
)
).match
self.context = context
self.traceback = traceback

def tokenize(self, text):
pos = 0
Expand Down Expand Up @@ -102,21 +103,47 @@ def apply_macro(self, expression, macro):

def apply_single_macro_to_all(self, expressions, macro):
    """Apply ``macro`` to every expression and collect the expanded results.

    Each expression is tokenized and passed to ``macro``; a macro may
    return either a single expression (``str``) or an iterable of
    expressions. When a traceback recorder is attached
    (``self.traceback``), each expansion is recorded so that a later
    failure can be traced back to the original AOA statement.

    Args:
        expressions: Iterable of AOA expression strings to transform.
        macro: Callable ``(tokens, expression, context) -> str | iterable``.

    Returns:
        list[str]: All expressions produced by applying ``macro`` to each input.

    Raises:
        AssertionError, ValueError, KeyError, RuntimeError: Re-raised from
            the macro after the failure chain has been recorded and printed
            (when a traceback recorder is configured).
    """
    new_expressions = []
    macro_name = getattr(macro, "__name__", "macro")
    for expr in expressions:
        try:
            results = macro(self.tokenize(expr), expr, self.context)
        except (AssertionError, ValueError, KeyError, RuntimeError) as e:
            if self.traceback:
                chain = self.traceback.build_chain(expr)
                self.traceback.add_error(
                    error_message=str(e),
                    # was f"{macro_name}": a redundant f-string wrapper
                    stage=macro_name,
                    chain=chain,
                    error_type=type(e).__name__,
                )
                self.traceback.print()
            raise

        # Normalize the macro's return value to a list of expressions.
        if isinstance(results, str):
            results_list = [results]
        else:
            results_list = list(results)

        # Record the expansion only when the macro actually changed something.
        if self.traceback and results_list != [expr]:
            self.traceback.record_children(expr, results_list, macro_name)

        new_expressions.extend(results_list)
    return new_expressions

def all_tokens(self, expressions):
if self.traceback:
self.traceback.register_roots(list(expressions))

current_expressions = expressions
for macro in self.macros:
current_expressions = self.apply_single_macro_to_all(
current_expressions, macro
)

self.final_expressions = list(current_expressions)
tokens = []
for expr in current_expressions:
tokens.extend(self.tokenize(expr))
Expand Down
32 changes: 22 additions & 10 deletions python/paddle/distributed/flex_checkpoint/aoa/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def extract_axis_and_clean_tokens(tokens):
axis = int(tokens[idx + 2].value)
end_idx = idx + 3
if end_idx < len(tokens) - 1:
assert tokens[end_idx].value == ","
assert tokens[end_idx].value == ",", (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个可不可以直接把aoa stmt打印出来

f"The different attributes must split by a comma, but now the token is {tokens[end_idx].value}."
)
end_idx += 1
tokens = tokens[:idx] + tokens[end_idx:]
break
Expand Down Expand Up @@ -141,7 +143,9 @@ def layer_id_offset_macro(tokens, expression, context):
),
None,
)
assert name_with_layer_id_offset, "No $LAYER_ID_OFFSET found in NAME tokens"
assert name_with_layer_id_offset, (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LAYER_ID,EXPERT_ID等匹配的Macro,如果匹配到了空的结果,往往是不符合预期的,这时希望给一个警告,同时把最原始的AOA statement打印出来,ID什么也没匹配到,可能不符合原始预期。

"No $LAYER_ID_OFFSET found in NAME tokens. Please check the aoa_config."
)
assert all(
(t.type != TokenType.IDENTIFIER)
or (LAYER_ID_OFFSET_MACRO_TAG in t.value)
Expand Down Expand Up @@ -197,6 +201,8 @@ def array_macro(tokens, expression, context):
and tokens[idx + 2].type == TokenType.COLON
and tokens[idx + 3].type == TokenType.NUMBER
and tokens[idx + 4].type == TokenType.RBRACKET
), (
f"The array macro format is incorrect which is must be like: NAME[START:END], but now the format is {tokens[idx].value}{tokens[idx + 1].value}:{tokens[idx + 3].value}{tokens[idx + 4].value}."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

)
new_tokens.pop()
start = int(tokens[idx + 1].value)
Expand Down Expand Up @@ -249,16 +255,18 @@ def fused_qkv_old_macro(tokens, expression, context):
):
right_var_end_pos = idx + 1

assert attn_head_num and attn_head_num > 0, "num_heads must be positive."
assert attn_head_num and attn_head_num > 0, (
f"num_heads must be positive.(got: {attn_head_num})."
)
assert num_key_value_groups and num_key_value_groups > 0, (
"num_key_value_groups must be positive."
f"num_key_value_groups must be positive.(got: {num_key_value_groups})."
)
assert fused_qkv_old_pos is not None, (
"No fused_qkv_old tag found in expression."
f"No fused_qkv_old tag found in expression. The tag must be {FUSED_QKV_OLD_TAG}."
)
assert rarrow_pos is not None, "No -> found in expression."
assert attn_head_num % num_key_value_groups == 0, (
"num_heads must be divisible by num_key_value_groups."
f"num_heads ({attn_head_num}) must be divisible by num_key_value_groups ({num_key_value_groups})."
)

results = []
Expand Down Expand Up @@ -413,7 +421,9 @@ def fused_ffn_macro(tokens, expression, context):
):
fused_ffn_pos = idx
assert rarrow_pos is not None, "No -> found in expression."
assert fused_ffn_pos is not None, "No fused_ffn tag found in expression."
assert fused_ffn_pos is not None, (
f"No fused_ffn tag found in expression. The tag must be {FUSED_FFN_TAG}."
)
results = []
if rarrow_pos == 1:
src_ffn_weight_name = tokens[0].value
Expand Down Expand Up @@ -607,7 +617,9 @@ def fused_qkv_macro(tokens, expression, context):
assert num_key_value_groups and num_key_value_groups > 0, (
f"num_key_value_groups must be positive (got: {num_key_value_groups})"
)
assert fused_qkv_pos is not None, "No fused_qkv tag found in expression."
assert fused_qkv_pos is not None, (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个也希望把aoa stmt直接打出来

f"No fused_qkv tag found in expression. The tag must be {FUSED_QKV_TAG}."
)
assert rarrow_pos is not None, "No -> found in expression."
assert rarrow_pos == 1 or rarrow_pos == 5, (
"Only support q,k,v -> fused_qkv or fused_qkv -> q,k,v patterns"
Expand Down Expand Up @@ -783,7 +795,7 @@ def id(tokens, expression, context):

for token in IDENTIFIER_tokens:
assert all(k in token.value for k in valid_keys), (
    f"The token: {token.value} must contain all of the following keys: {valid_keys}. When using the id macro, all IDENTIFIER tokens must contain the same ID placeholders."
)

def dict_cartesian_tuples(d: dict[str, list[int]]):
Expand Down Expand Up @@ -837,7 +849,7 @@ def get_var_mapping_chain_macro(tokens, expression, context):
else:
right_var_list.append(extra_suffix_removed_value)
assert len(left_var_list) == 1 or len(right_var_list) == 1, (
    "Left or right variable must have only one element; the aoa_statements do not support the 'multiple var -> multiple var' pattern."
)
if len(left_var_list) == 1:
context.left_var_to_right_var_mapping[left_var_list[0]] = right_var_list
Expand Down
133 changes: 133 additions & 0 deletions python/paddle/distributed/flex_checkpoint/aoa/traceback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations


class AOATraceback:
    """
    When error occurs, print the chain of "original aoa_statement -> ... -> current expression".
    """

    def __init__(self) -> None:
        # Flat log of every recorded error event (see add_error).
        self.records: list[dict] = []
        # Details of the most recent error; consumed by format_traceback().
        self.last_error_chain: list[str] = []
        self.last_error_message: str = ""
        self.last_error_stage: str = ""
        self.last_error_type: str = ""
        # child expression -> parent expression (None marks a root statement).
        self.parent_map: dict[str, str | None] = {}
        # child expression -> name of the macro that produced it.
        self.child_macro_map: dict[str, str] = {}

    def register_roots(self, expressions: list[str]) -> None:
        """Register the original aoa_statements as the root nodes of the chain."""
        for expression in expressions:
            # setdefault: do not clobber a parent link recorded earlier.
            self.parent_map.setdefault(expression, None)

    def record_children(
        self, parent: str, children: list[str], macro_name: str | None = None
    ) -> None:
        """Record the children expressions obtained by the parent expression, and mark the macro name used."""
        label = macro_name or "Expanded"
        for child in children:
            # A child identical to its parent would create a self-loop; skip it.
            if child == parent:
                continue
            self.parent_map[child] = parent
            self.child_macro_map[child] = label

    def build_chain(self, expr: str) -> list[str]:
        """Build the chain from the root to expr by tracing back from the current expression."""
        seen: set[str] = set()
        path: list[str] = []
        node = expr
        # Walk parent links upward; the seen-set guards against cycles.
        while node is not None and node not in seen:
            seen.add(node)
            path.append(node)
            node = self.parent_map.get(node)
        # Collected leaf-first; return root-first.
        return path[::-1]

    def add_error(
        self,
        error_message: str,
        stage: str,
        chain: list[str],
        error_type: str = "",
    ) -> None:
        """Record the error chain and information."""
        self.last_error_chain = chain
        self.last_error_message = error_message
        self.last_error_stage = stage
        self.last_error_type = error_type or ""
        record = {
            "type": "error",
            "stage": stage,
            "message": error_message,
            "error_type": self.last_error_type,
            "chain": chain,
        }
        self.records.append(record)

    def format_traceback(self) -> str:
        """Render the last recorded error chain as a framed, indented text block."""
        header_text = " AOA Traceback (related chain) "
        header = f"===={header_text}===="
        footer = "=" * len(header)
        lines: list[str] = [header]

        # No error recorded yet: emit an empty frame.
        if not self.last_error_chain:
            lines.append("(No trace records)")
            lines.append(footer)
            return "\n".join(lines)

        indent_unit = " "

        lines.append("| Origin AOA Statement")
        root_expr = self.last_error_chain[0].replace("\n", " ")
        lines.append(f"|-> {root_expr}")

        # Each subsequent chain entry is indented one level deeper.
        for depth, node in enumerate(self.last_error_chain[1:], start=1):
            pad = indent_unit * depth
            flat_expr = node.replace("\n", " ")
            label = self.child_macro_map.get(
                node, self.last_error_stage or "Expanded"
            )
            lines.append(f"{pad}| {label}")
            lines.append(f"{pad}|-> {flat_expr}")

        if self.last_error_message:
            title = self.last_error_type or "Error"
            stage_part = (
                f" [{self.last_error_stage}]" if self.last_error_stage else ""
            )
            pad = indent_unit * len(self.last_error_chain)
            flat_msg = self.last_error_message.replace("\n", " ")
            lines.append(f"{pad}| Error")
            lines.append(f"{pad}|-> ({title}{stage_part}) {flat_msg}")

        lines.append(footer)
        return "\n".join(lines)

    def print(self, logger=None) -> None:
        """Emit the formatted traceback via logger.error, or stdout when no logger is given."""
        text = self.format_traceback()
        if logger:
            logger.error(text)
        else:
            print(text)
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def _run_single_card(
"""Simple assembly path for a single GPU."""
for k, v in self.filtered_sharded_state_dict.items():
assert v.local_shape == v.global_shape, (
    f"Single card params must not be sharded. But now the key is {k}, the local_shape is {v.local_shape}, the global_shape is {v.global_shape}."
)

for k, shard_mappings in self.destination_sharded_mappings.items():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,12 @@ def get_rank_to_files(
f"Missing keys:{missing_keys}, check whether the checkpoint is complete."
)

unexpected_keys = set(tensor_key_list) - set(state_dict_param_names)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议 每个key单独一行 前后加一个醒目的界符

if len(unexpected_keys) > 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

logger.warning(
f"Unexpected keys:{unexpected_keys}, these keys exist in checkpoint but not in state_dict."
)

rank_to_files = {}
for rank, need_files in enumerate(all_necessary_files):
seen = set()
Expand Down Expand Up @@ -1073,7 +1079,7 @@ def load_state_dict_impl(
with paddle.base.dygraph.guard():
global _metadata_manager
assert isinstance(state_dict, dict), (
"The state_dict should be a dictionary."
f"The state_dict should be a dictionary. But now the type is {type(state_dict)}."
)
first_key = next(iter(state_dict), None)
if isinstance(first_key, tuple):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def save_state_dict(
"""
with paddle.base.dygraph.guard():
assert isinstance(state_dict, dict), (
"The state_dict should be a dictionary."
f"The state_dict should be a dictionary. But now the type is {type(state_dict)}."
)
flat_state_dict, mapping = flatten_state_dict(state_dict)
if len(flat_state_dict) > 0:
Expand Down
Loading
Loading