diff --git a/python/paddle/distributed/flex_checkpoint/aoa/aoa_engine.py b/python/paddle/distributed/flex_checkpoint/aoa/aoa_engine.py
index a408f428a4402d..85dcc3124ef2e7 100644
--- a/python/paddle/distributed/flex_checkpoint/aoa/aoa_engine.py
+++ b/python/paddle/distributed/flex_checkpoint/aoa/aoa_engine.py
@@ -26,6 +26,7 @@
 from ..dcp.sharded_weight import ShardedWeightDesc
 from .lexer import Lexer
 from .parser import Parser
+from .traceback import AOATraceback
 
 _ShardInfo = dict[str, list[ShardedWeightDesc]]
 
@@ -143,7 +144,7 @@ def get_src_state_shard_num(self, src_state_key: str) -> int:
         )
 
         assert opt_state_name is None, (
-            "AOA notions apply only to the model state, but are automatically propagated to the optimizer state."
+            f"AOA notions apply only to the model state, but are automatically propagated to the optimizer state. The given src_state_key {src_state_key} is an optimizer state key."
         )
         reverse = True
         if self.aoa_config_reverse:
@@ -178,7 +179,7 @@ def get_src_state_shard_num(self, src_state_key: str) -> int:
             return 1
         if len(shard_nums) > 1:
             raise AssertionError(
-                f"Inconsistent shard numbers among keys in source_sharded_state_dict: {shard_nums}."
+                f"Inconsistent shard numbers among keys in source_sharded_state_dict for the key {src_state_key}: shard_nums={shard_nums}."
             )
         return shard_nums.pop()
 
@@ -191,7 +192,7 @@ def get_dst_state_shard_num(self, dst_state_key: str) -> int:
         )
 
         assert opt_state_name is None, (
-            "AOA notions apply only to the model state, but are automatically propagated to the optimizer state."
+            f"AOA notions apply only to the model state, but are automatically propagated to the optimizer state. The given dst_state_key {dst_state_key} is an optimizer state key."
         )
         reverse = False
         if self.aoa_config_reverse:
@@ -226,7 +227,7 @@ def get_dst_state_shard_num(self, dst_state_key: str) -> int:
             return 1
         if len(shard_nums) > 1:
             raise AssertionError(
-                f"Inconsistent shard numbers among keys in destination_state_shard_info: {shard_nums}."
+                f"Inconsistent shard numbers among keys in destination_state_shard_info for the key {dst_state_key}: shard_nums={shard_nums}."
            )
         return shard_nums.pop()
 
@@ -252,7 +253,7 @@ def resolve_mapping_chain(self, key: str, reverse: bool = False) -> str:
 
         while current_key in mapping_dict:
             assert current_key not in visited, (
-                "Infinite loop detected in resolve_mapping_chain,which means the start key is not src_key or the end key is not dst_key, the aoa_config is error"
+                f"Infinite loop detected in resolve_mapping_chain, which means the start key is not a src_key or the end key is not a dst_key; the aoa_config is invalid. 
current_key={current_key}, the loop is: {'->'.join(visited)}->{current_key}" ) visited.add(current_key) if reverse and current_key in self.get_all_src_state_keys(): @@ -282,16 +283,28 @@ def __init__( self.aoa_config_reverse = self.aoa_config.get( "aoa_config_reverse", False ) + enable_traceback = self.aoa_config.get("enable_traceback", True) + self.traceback = AOATraceback() if enable_traceback else None self.context = AOAShardInfoContext( source_state_shard_info, destination_state_shard_info, self.aoa_config_reverse, ) - self.lexer = Lexer(self.context) - self.parser = Parser( - self.lexer.all_tokens(self.aoa_config.get("aoa_statements", [])) + self.lexer = Lexer(self.context, traceback=self.traceback) + tokens = self.lexer.all_tokens( + self.aoa_config.get("aoa_statements", []) ) + self.parser = Parser(tokens) self.statements = self.parser.parse_program() + + if self.traceback and getattr(self.lexer, "final_expressions", None): + final_exprs = self.lexer.final_expressions + if len(final_exprs) == len(self.statements): + for expr, stmt in zip(final_exprs, self.statements): + self.traceback.record_children( + expr, [repr(stmt)], macro_name="parser" + ) + if self.aoa_config_reverse: self.statements = list(reversed(self.statements)) self.input_vars = self.build_input_vars() @@ -413,7 +426,7 @@ def concat(self, tensors: list[TensorDesc], axis: int) -> TensorDesc: shape[axis] = sum(t.shape[axis] for t in tensors) dtype = tensors[0].dtype assert all(t.dtype == dtype for t in tensors), ( - "All tensors must have the same dtype!" + f"All tensors must have the same dtype when concatenating multiple tensors!But the tensors {tensors} have different dtypes: {[t.dtype for t in tensors]}." ) curr = 0 for t in tensors: @@ -502,123 +515,153 @@ def _get_var_ref(var): raise ValueError(f"{var.name} should be assigned before!") for stmt in self.statements: + stmt_repr = repr(stmt) left_vars = stmt.left_vars right_vars = stmt.right_vars if self.aoa_config_reverse: left_vars, right_vars = right_vars, left_vars attrs = stmt.attrs - if len(left_vars) > 1 or len(right_vars) > 1: - if not (len(attrs) == 1 and attrs[0].key == "axis"): - raise ValueError( - "When split/concat, only support one attr named `axis`" - ) - axis = attrs[0].value - - if len(left_vars) == 1: - in_name = left_vars[0].name - in_ref = _get_var_ref(left_vars[0]) - assert in_ref.shape[axis] % len(right_vars) == 0 - sizes = [ - in_ref.shape[axis] // len(right_vars) - for var in right_vars - ] - result = self.split(in_ref, axis, sizes) - for out_var, out_ref in zip(right_vars, result): - self.intermediate_vars[out_var.name] = out_ref - if ( - out_var.name - in self.context.get_all_dst_state_keys() - ): - self.output_vars[out_var.name] = out_ref - - elif len(right_vars) == 1: - left_refs = [_get_var_ref(var) for var in left_vars] - result = self.concat(left_refs, axis) - out_name = right_vars[0].name - self.intermediate_vars[out_name] = result - if out_name in self.context.get_all_dst_state_keys(): - self.output_vars[out_name] = result - else: - raise SyntaxError( - f'Unexpected split/concat statement: {stmt}' - ) + try: + if len(left_vars) > 1 or len(right_vars) > 1: + if not (len(attrs) == 1 and attrs[0].key == "axis"): + raise ValueError( + f"When split/concat, only support one attr named `axis`, but got {attrs}." 
+ ) + axis = attrs[0].value - elif len(left_vars) == 1 and len(right_vars) == 1: - lvar, rvar = left_vars[0], right_vars[0] - if rvar.name == "_": - self.need_remove_input_vars.add(lvar.name) - elif lvar.name == "_": - self.need_add_output_vars.add(rvar.name) - else: - if len(attrs) > 0: - assert len(attrs) == 1 or ( - len(attrs) == 2 - and {attr.key for attr in attrs} - == {"src_dtype", "dst_dtype"} - ), ( - "Only support:\n" - " - One operator, OR\n" - " - Two operators with keys {'src_dtype', 'dst_dtype'}." + if len(left_vars) == 1: + in_name = left_vars[0].name + in_ref = _get_var_ref(left_vars[0]) + assert in_ref.shape[axis] % len(right_vars) == 0, ( + f"when split, the shape of the input tensor {in_name} is {in_ref.shape}, the axis is {axis}, the number of right_vars is {len(right_vars)}, but the shape of the input tensor {in_name} is not divisible by the number of right_vars." ) - attr = attrs[0] - in_ref = _get_var_ref(lvar) - if attr.key == "permute": - if attr.value == "[]": - ndim = len(in_ref.shape) - perm = str(list(range(ndim - 1, -1, -1))) - else: - perm = attr.value - if self.aoa_config_reverse: - perm = str( - invert_permutation( - ast.literal_eval(perm) - ) - ) - result = self.transpose(in_ref, perm) - elif attr.key == "dtype": - assert not self.aoa_config_reverse, ( - "When `aoa_config_reverse=True`, the dtype must be specified as " - "'src_dtype=...,dst_dtype=...'. Formats like 'dtype=xxx' are not supported." - ) - assert attr.value in SUPPORTED_DTYPES, ( - f"Unsupported cast dtype: {attr.value}" - ) - result = self.cast(in_ref, attr.value) - elif ( - attrs[0].key == "src_dtype" - and attrs[1].key == "dst_dtype" - ): - src_dtype, dst_dtype = ( - attrs[0].value, - attrs[1].value, - ) - assert src_dtype in SUPPORTED_DTYPES, ( - f"Unsupported cast dtype: {src_dtype}" - ) - assert dst_dtype in SUPPORTED_DTYPES, ( - f"Unsupported cast dtype: {dst_dtype}" - ) - if self.aoa_config_reverse: - src_dtype, dst_dtype = dst_dtype, src_dtype - result = self.cast(in_ref, dst_dtype) - elif attr.key == "axis": - result = in_ref - else: - raise ValueError(f"Unsupported attribute: {attr}") + sizes = [ + in_ref.shape[axis] // len(right_vars) + for var in right_vars + ] + result = self.split(in_ref, axis, sizes) + for out_var, out_ref in zip(right_vars, result): + self.intermediate_vars[out_var.name] = out_ref + if ( + out_var.name + in self.context.get_all_dst_state_keys() + ): + self.output_vars[out_var.name] = out_ref + + elif len(right_vars) == 1: + left_refs = [_get_var_ref(var) for var in left_vars] + result = self.concat(left_refs, axis) + out_name = right_vars[0].name + self.intermediate_vars[out_name] = result + if out_name in self.context.get_all_dst_state_keys(): + self.output_vars[out_name] = result - self.intermediate_vars[rvar.name] = result - if rvar.name in self.context.get_all_dst_state_keys(): - self.output_vars[rvar.name] = result else: - # rename operation - in_ref = _get_var_ref(lvar) - result = self.identity(in_ref) - self.intermediate_vars[rvar.name] = result - if rvar.name in self.context.get_all_dst_state_keys(): - self.output_vars[rvar.name] = result - else: - raise SyntaxError(f'Unexpected statement: {stmt}') + raise SyntaxError( + f'Unexpected split/concat statement: {stmt}' + ) + + elif len(left_vars) == 1 and len(right_vars) == 1: + lvar, rvar = left_vars[0], right_vars[0] + if rvar.name == "_": + self.need_remove_input_vars.add(lvar.name) + elif lvar.name == "_": + self.need_add_output_vars.add(rvar.name) + else: + if len(attrs) > 0: + assert len(attrs) == 1 
or ( + len(attrs) == 2 + and {attr.key for attr in attrs} + == {"src_dtype", "dst_dtype"} + ), ( + "Only support:\n" + " - One operator, OR\n" + " - Two operators with keys {'src_dtype', 'dst_dtype'}." + ) + attr = attrs[0] + in_ref = _get_var_ref(lvar) + if attr.key == "permute": + if attr.value == "[]": + ndim = len(in_ref.shape) + perm = str(list(range(ndim - 1, -1, -1))) + else: + perm = attr.value + if self.aoa_config_reverse: + perm = str( + invert_permutation( + ast.literal_eval(perm) + ) + ) + result = self.transpose(in_ref, perm) + elif attr.key == "dtype": + assert not self.aoa_config_reverse, ( + "When `aoa_config_reverse=True`, the dtype must be specified as " + "'src_dtype=...,dst_dtype=...'. Formats like 'dtype=xxx' are not supported." + ) + assert attr.value in SUPPORTED_DTYPES, ( + f"Unsupported cast dtype: {attr.value}" + ) + result = self.cast(in_ref, attr.value) + elif ( + attrs[0].key == "src_dtype" + and attrs[1].key == "dst_dtype" + ): + src_dtype, dst_dtype = ( + attrs[0].value, + attrs[1].value, + ) + assert src_dtype in SUPPORTED_DTYPES, ( + f"Unsupported cast dtype: {src_dtype}" + ) + assert dst_dtype in SUPPORTED_DTYPES, ( + f"Unsupported cast dtype: {dst_dtype}" + ) + if self.aoa_config_reverse: + src_dtype, dst_dtype = dst_dtype, src_dtype + result = self.cast(in_ref, dst_dtype) + elif attr.key == "axis": + result = in_ref + else: + raise ValueError( + f"Unsupported attribute: {attr}" + ) + + self.intermediate_vars[rvar.name] = result + if ( + rvar.name + in self.context.get_all_dst_state_keys() + ): + self.output_vars[rvar.name] = result + else: + # rename operation + in_ref = _get_var_ref(lvar) + result = self.identity(in_ref) + self.intermediate_vars[rvar.name] = result + if ( + rvar.name + in self.context.get_all_dst_state_keys() + ): + self.output_vars[rvar.name] = result + else: + raise SyntaxError(f'Unexpected statement: {stmt}') + except ( + AssertionError, + ValueError, + KeyError, + SyntaxError, + RuntimeError, + ) as e: + if self.traceback: + chain = self.traceback.build_chain(stmt_repr) + self.traceback.add_error( + error_message=str(e), + stage="shape_propagation", + chain=chain, + error_type=type(e).__name__, + ) + self.traceback.print() + raise if self.destination_state_shard_info is not None: for name in self.destination_state_shard_info: model_state_key, _ = split_optimizer_state_key(name) @@ -648,12 +691,17 @@ def _get_var_ref(var): def find_source_slices( self, key: str, local_slice: tuple[slice, ...] ) -> list[SliceRef]: - assert key in self.output_vars + assert key in self.output_vars, ( + f"The key {key} is not in the output_vars (which is built during load_state_dict)." + ) tensor = self.output_vars[key] if tensor is None: return [] results = [] - assert len(local_slice) == len(tensor.shape) + assert len(local_slice) == len(tensor.shape), ( + f"For the key {key}, the target_tensor has {len(local_slice)} dimensions, " + f"but the tensor in output_vars has {len(tensor.shape)} dimensions (shape={tensor.shape}). " + ) ndim = len(tensor.shape) def slice_intersect(a: slice, b: slice): @@ -718,7 +766,9 @@ def find_shard_sources( target_global_shape = target.global_shape if opt_state_name in [".beta1_pow_acc_0", ".beta2_pow_acc_0"]: - assert target_key in self.output_vars + assert target_key in self.output_vars, ( + f"The key {target_key} is not in the output_vars (which is built during load_state_dict)." 
+ ) tensor = self.output_vars[target_key] target_local_shape = tensor.shape target_global_offset = (0,) * len(target_local_shape) @@ -779,7 +829,9 @@ def find_shard_sources( pp_list ), ( "Direct assignment of Tensors with different types is prohibited in AOA. " - "If you want to achieve this functionality, please use the cast semantics provided by AOA." + f"If you want to achieve this functionality, please use the cast semantics provided by AOA. " + f"Now the src_var.dtype is {src_var.dtype}, the target.dtype is {target.dtype}, the pp_list is {pp_list}." + f"The src_key is {src_key}, the target_key is {target.key}." ) else: src_var.dtype = target.dtype diff --git a/python/paddle/distributed/flex_checkpoint/aoa/lexer.py b/python/paddle/distributed/flex_checkpoint/aoa/lexer.py index fa1a78a20eb817..ab0efe82d0c1e8 100644 --- a/python/paddle/distributed/flex_checkpoint/aoa/lexer.py +++ b/python/paddle/distributed/flex_checkpoint/aoa/lexer.py @@ -55,7 +55,7 @@ class Lexer: ('MISMATCH', r'.'), ] - def __init__(self, context): + def __init__(self, context, traceback=None): from .macros import macro_registry self.macros = [list(d.values())[1] for d in macro_registry.macros] @@ -66,6 +66,7 @@ def __init__(self, context): ) ).match self.context = context + self.traceback = traceback def tokenize(self, text): pos = 0 @@ -102,21 +103,47 @@ def apply_macro(self, expression, macro): def apply_single_macro_to_all(self, expressions, macro): new_expressions = [] + macro_name = getattr(macro, "__name__", "macro") for expr in expressions: - results = macro(self.tokenize(expr), expr, self.context) + try: + results = macro(self.tokenize(expr), expr, self.context) + except (AssertionError, ValueError, KeyError, RuntimeError) as e: + if self.traceback: + chain = self.traceback.build_chain(expr) + self.traceback.add_error( + error_message=str(e), + stage=f"{macro_name}", + chain=chain, + error_type=type(e).__name__, + ) + self.traceback.print() + raise + if isinstance(results, str): - new_expressions.append(results) + results_list = [results] else: - new_expressions.extend(results) + results_list = list(results) + + if self.traceback: + if results_list != [expr]: + self.traceback.record_children( + expr, results_list, macro_name + ) + + new_expressions.extend(results_list) return new_expressions def all_tokens(self, expressions): + if self.traceback: + self.traceback.register_roots(list(expressions)) + current_expressions = expressions for macro in self.macros: current_expressions = self.apply_single_macro_to_all( current_expressions, macro ) + self.final_expressions = list(current_expressions) tokens = [] for expr in current_expressions: tokens.extend(self.tokenize(expr)) diff --git a/python/paddle/distributed/flex_checkpoint/aoa/macros.py b/python/paddle/distributed/flex_checkpoint/aoa/macros.py index ba16824b75bca5..c22666731f5b50 100644 --- a/python/paddle/distributed/flex_checkpoint/aoa/macros.py +++ b/python/paddle/distributed/flex_checkpoint/aoa/macros.py @@ -74,7 +74,9 @@ def extract_axis_and_clean_tokens(tokens): axis = int(tokens[idx + 2].value) end_idx = idx + 3 if end_idx < len(tokens) - 1: - assert tokens[end_idx].value == "," + assert tokens[end_idx].value == ",", ( + f"The different attributes must split by a comma, but now the token is {tokens[end_idx].value}." 
+ ) end_idx += 1 tokens = tokens[:idx] + tokens[end_idx:] break @@ -141,7 +143,9 @@ def layer_id_offset_macro(tokens, expression, context): ), None, ) - assert name_with_layer_id_offset, "No $LAYER_ID_OFFSET found in NAME tokens" + assert name_with_layer_id_offset, ( + "No $LAYER_ID_OFFSET found in NAME tokens.Please check the aoa_config." + ) assert all( (t.type != TokenType.IDENTIFIER) or (LAYER_ID_OFFSET_MACRO_TAG in t.value) @@ -197,6 +201,8 @@ def array_macro(tokens, expression, context): and tokens[idx + 2].type == TokenType.COLON and tokens[idx + 3].type == TokenType.NUMBER and tokens[idx + 4].type == TokenType.RBRACKET + ), ( + f"The array macro format is incorrect which is must be like: NAME[START:END], but now the format is {tokens[idx].value}{tokens[idx + 1].value}:{tokens[idx + 3].value}{tokens[idx + 4].value}." ) new_tokens.pop() start = int(tokens[idx + 1].value) @@ -249,16 +255,18 @@ def fused_qkv_old_macro(tokens, expression, context): ): right_var_end_pos = idx + 1 - assert attn_head_num and attn_head_num > 0, "num_heads must be positive." + assert attn_head_num and attn_head_num > 0, ( + f"num_heads must be positive.(got: {attn_head_num})." + ) assert num_key_value_groups and num_key_value_groups > 0, ( - "num_key_value_groups must be positive." + f"num_key_value_groups must be positive.(got: {num_key_value_groups})." ) assert fused_qkv_old_pos is not None, ( - "No fused_qkv_old tag found in expression." + f"No fused_qkv_old tag found in expression. The tag must be {FUSED_QKV_OLD_TAG}." ) assert rarrow_pos is not None, "No -> found in expression." assert attn_head_num % num_key_value_groups == 0, ( - "num_heads must be divisible by num_key_value_groups." + f"num_heads ({attn_head_num}) must be divisible by num_key_value_groups ({num_key_value_groups})." ) results = [] @@ -413,7 +421,9 @@ def fused_ffn_macro(tokens, expression, context): ): fused_ffn_pos = idx assert rarrow_pos is not None, "No -> found in expression." - assert fused_ffn_pos is not None, "No fused_ffn tag found in expression." + assert fused_ffn_pos is not None, ( + f"No fused_ffn tag found in expression. The tag must be {FUSED_FFN_TAG}." + ) results = [] if rarrow_pos == 1: src_ffn_weight_name = tokens[0].value @@ -607,7 +617,9 @@ def fused_qkv_macro(tokens, expression, context): assert num_key_value_groups and num_key_value_groups > 0, ( f"num_key_value_groups must be positive (got: {num_key_value_groups})" ) - assert fused_qkv_pos is not None, "No fused_qkv tag found in expression." + assert fused_qkv_pos is not None, ( + f"No fused_qkv tag found in expression. The tag must be {FUSED_QKV_TAG}." + ) assert rarrow_pos is not None, "No -> found in expression." assert rarrow_pos == 1 or rarrow_pos == 5, ( "Only support q,k,v -> fused_qkv or fused_qkv -> q,k,v patterns" @@ -783,7 +795,7 @@ def id(tokens, expression, context): for token in IDENTIFIER_tokens: assert all(k in token.value for k in valid_keys), ( - f"The token: {token.value} must contain all of the following keys: {valid_keys}" + f"The token: {token.value} must contain all of the following keys: {valid_keys}.When use the id macro all IDENTIFIER tokens must contain the same ID placeholders." 
) def dict_cartesian_tuples(d: dict[str, list[int]]): @@ -837,7 +849,7 @@ def get_var_mapping_chain_macro(tokens, expression, context): else: right_var_list.append(extra_suffix_removed_value) assert len(left_var_list) == 1 or len(right_var_list) == 1, ( - "Left or right variable must have the only one element" + "Left or right variable must have the only one element,the aoa_statements not support 'multiple var -> multiple var' pattern." ) if len(left_var_list) == 1: context.left_var_to_right_var_mapping[left_var_list[0]] = right_var_list diff --git a/python/paddle/distributed/flex_checkpoint/aoa/traceback.py b/python/paddle/distributed/flex_checkpoint/aoa/traceback.py new file mode 100644 index 00000000000000..b4f48f80d07713 --- /dev/null +++ b/python/paddle/distributed/flex_checkpoint/aoa/traceback.py @@ -0,0 +1,133 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + + +class AOATraceback: + """ + When error occurs, print the chain of "original aoa_statement -> ... -> current expression". + """ + + def __init__(self) -> None: + self.records: list[dict] = [] + self.last_error_chain: list[str] = [] + self.last_error_message: str = "" + self.last_error_stage: str = "" + self.last_error_type: str = "" + self.parent_map: dict[str, str | None] = {} + self.child_macro_map: dict[str, str] = {} + + def register_roots(self, expressions: list[str]) -> None: + """Register the original aoa_statements as the root nodes of the chain.""" + for expr in expressions: + self.parent_map.setdefault(expr, None) + + def record_children( + self, parent: str, children: list[str], macro_name: str | None = None + ) -> None: + """Record the children expressions obtained by the parent expression, and mark the macro name used.""" + macro = macro_name or "Expanded" + for child in children: + if child == parent: + continue + self.parent_map[child] = parent + self.child_macro_map[child] = macro + + def build_chain(self, expr: str) -> list[str]: + """Build the chain from the root to expr by tracing back from the current expression.""" + chain: list[str] = [] + visited = set() + cur = expr + while cur is not None and cur not in visited: + chain.append(cur) + visited.add(cur) + cur = self.parent_map.get(cur) + chain.reverse() + return chain + + def add_error( + self, + error_message: str, + stage: str, + chain: list[str], + error_type: str = "", + ) -> None: + """Record the error chain and information.""" + self.last_error_chain = chain + self.last_error_message = error_message + self.last_error_stage = stage + self.last_error_type = error_type or "" + self.records.append( + { + "type": "error", + "stage": stage, + "message": error_message, + "error_type": self.last_error_type, + "chain": chain, + } + ) + + def format_traceback(self) -> str: + lines: list[str] = [] + header_text = " AOA Traceback (related chain) " + header = f"===={header_text}====" + footer = "=" * len(header) + + if self.last_error_chain: + 
lines.append(header) + indent_unit = " " + + lines.append("| Origin AOA Statement") + origin_expr = self.last_error_chain[0].replace("\n", " ") + lines.append(f"|-> {origin_expr}") + + for level, expr in enumerate(self.last_error_chain[1:], start=1): + indent = indent_unit * level + single_line_expr = expr.replace("\n", " ") + macro = self.child_macro_map.get( + expr, self.last_error_stage or "Expanded" + ) + lines.append(f"{indent}| {macro}") + lines.append(f"{indent}|-> {single_line_expr}") + + if self.last_error_message: + err_title = self.last_error_type or "Error" + stage_str = ( + f" [{self.last_error_stage}]" + if self.last_error_stage + else "" + ) + err_level = len(self.last_error_chain) + indent = indent_unit * err_level + single_line_msg = self.last_error_message.replace("\n", " ") + lines.append(f"{indent}| Error") + lines.append( + f"{indent}|-> ({err_title}{stage_str}) {single_line_msg}" + ) + + lines.append(footer) + else: + lines.append(header) + lines.append("(No trace records)") + lines.append(footer) + + return "\n".join(lines) + + def print(self, logger=None) -> None: + text = self.format_traceback() + if logger: + logger.error(text) + else: + print(text) diff --git a/python/paddle/distributed/flex_checkpoint/dcp/full_param.py b/python/paddle/distributed/flex_checkpoint/dcp/full_param.py index 04c8059df164bd..3fd580911bffa2 100644 --- a/python/paddle/distributed/flex_checkpoint/dcp/full_param.py +++ b/python/paddle/distributed/flex_checkpoint/dcp/full_param.py @@ -472,7 +472,7 @@ def _run_single_card( """Simple assembly path for a single GPU.""" for k, v in self.filtered_sharded_state_dict.items(): assert v.local_shape == v.global_shape, ( - "Single card params must not be sharded." + "Single card params must not be sharded.But now the key is {k}, the local_shape is {v.local_shape}, the global_shape is {v.global_shape}." ) for k, shard_mappings in self.destination_sharded_mappings.items(): diff --git a/python/paddle/distributed/flex_checkpoint/dcp/load_state_dict.py b/python/paddle/distributed/flex_checkpoint/dcp/load_state_dict.py index 87090d165e4d21..ec36442cd1bced 100644 --- a/python/paddle/distributed/flex_checkpoint/dcp/load_state_dict.py +++ b/python/paddle/distributed/flex_checkpoint/dcp/load_state_dict.py @@ -223,6 +223,12 @@ def get_rank_to_files( f"Missing keys:{missing_keys}, check whether the checkpoint is complete." ) + unexpected_keys = set(tensor_key_list) - set(state_dict_param_names) + if len(unexpected_keys) > 0: + logger.warning( + f"Unexpected keys:{unexpected_keys}, these keys exist in checkpoint but not in state_dict." + ) + rank_to_files = {} for rank, need_files in enumerate(all_necessary_files): seen = set() @@ -1073,7 +1079,7 @@ def load_state_dict_impl( with paddle.base.dygraph.guard(): global _metadata_manager assert isinstance(state_dict, dict), ( - "The state_dict should be a dictionary." + f"The state_dict should be a dictionary.But now the type is {type(state_dict)}." ) first_key = next(iter(state_dict), None) if isinstance(first_key, tuple): diff --git a/python/paddle/distributed/flex_checkpoint/dcp/save_state_dict.py b/python/paddle/distributed/flex_checkpoint/dcp/save_state_dict.py index 85fcedf77ede4c..2f9c84d06c67fd 100644 --- a/python/paddle/distributed/flex_checkpoint/dcp/save_state_dict.py +++ b/python/paddle/distributed/flex_checkpoint/dcp/save_state_dict.py @@ -188,7 +188,7 @@ def save_state_dict( """ with paddle.base.dygraph.guard(): assert isinstance(state_dict, dict), ( - "The state_dict should be a dictionary." 
+ f"The state_dict should be a dictionary.But now the type is {type(state_dict)}." ) flat_state_dict, mapping = flatten_state_dict(state_dict) if len(flat_state_dict) > 0: diff --git a/test/flex_checkpoint/test_AOATraceback.py b/test/flex_checkpoint/test_AOATraceback.py new file mode 100644 index 00000000000000..3cf5ee496f02fa --- /dev/null +++ b/test/flex_checkpoint/test_AOATraceback.py @@ -0,0 +1,103 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from paddle.distributed.flex_checkpoint.aoa.aoa_engine import AOAEngine +from paddle.distributed.flex_checkpoint.dcp.sharded_weight import ( + ShardedWeightDesc, +) + + +def create_shard_info(keys, shape=(4, 4), dtype="float32"): + info = {} + for k in keys: + desc = ShardedWeightDesc( + key=k, + local_shape=shape, + global_shape=shape, + global_offset=(0, 0), + dtype=dtype, + ) + info[k] = [desc] + return info + + +class TestMacroLayerOffsetError(unittest.TestCase): + def setUp(self): + self.source_keys = [f"model.layers.{i}.weight" for i in range(10)] + self.dest_keys = [f"model.layers.{i}.weight_out" for i in range(10)] + [ + f"model.layers.{i}.weight_out2" for i in range(10) + ] + + self.src_info = create_shard_info(self.source_keys) + self.dst_info = create_shard_info(self.dest_keys) + + def test_macro_error_chain(self): + """ + The statement contains fused_qkv_old and is missing a comma, expecting to trigger the assertion and print the chain. + """ + aoa_config = { + "aoa_statements": [ + "model.layers.$LAYER_ID.weight^T -> model.layers.$LAYER_ID.weight_out, axis=0 fused_qkv_old, num_heads=20,num_key_value_groups=4", + ], + "enable_traceback": True, + } + + with self.assertRaises(AssertionError): + AOAEngine( + aoa_config=aoa_config, + source_state_shard_info=self.src_info, + destination_state_shard_info=self.dst_info, + ) + + def test_no_error_should_be_raised(self): + # No error should be raised + source_keys = ["model.layers.0.weight"] + dest_keys = ["model.layers.0.weight_out"] + src_info = create_shard_info(source_keys) + dst_info = create_shard_info(dest_keys) + aoa_config = { + "aoa_statements": [ + "model.layers.0.weight^T -> model.layers.0.weight_out", + ], + "enable_traceback": True, + } + AOAEngine( + aoa_config=aoa_config, + source_state_shard_info=src_info, + destination_state_shard_info=dst_info, + ) + + def test_shape_propagation_error_chain(self): + """ + when split/concat, only support one attr named `axis`, but got multiple attrs. + """ + aoa_config = { + "aoa_statements": [ + "model.layers.0.weight -> model.layers.0.weight_out,model.layers.0.weight_out2,axis=0,axis=1", + ], + "enable_traceback": True, + } + + with self.assertRaises(ValueError): + AOAEngine( + aoa_config=aoa_config, + source_state_shard_info=self.src_info, + destination_state_shard_info=self.dst_info, + ) + + +if __name__ == "__main__": + unittest.main()
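Reviewer note: the short sketch below is a minimal, illustrative walk-through of the new AOATraceback API added in aoa/traceback.py (register_roots, record_children, build_chain, add_error, print). The statement and expansion strings are made up for illustration only and are not taken from a real aoa_config.

# Minimal sketch of the AOATraceback chain; illustrative strings, not a real aoa_config.
from paddle.distributed.flex_checkpoint.aoa.traceback import AOATraceback

tb = AOATraceback()

# The lexer registers the original aoa_statements as the roots of the chain.
root = "a -> b, fused_qkv, num_heads=8, num_key_value_groups=2"
tb.register_roots([root])

# Each macro records the expressions it expands a parent expression into.
child = "a -> b_q, b_k, b_v, axis=0"
tb.record_children(root, [child], macro_name="fused_qkv_macro")

# When a later stage fails, the chain from the root statement to the failing
# expression is rebuilt and printed together with the error.
chain = tb.build_chain(child)
tb.add_error(
    error_message="num_heads must be divisible by num_key_value_groups.",
    stage="shape_propagation",
    chain=chain,
    error_type="AssertionError",
)
tb.print()  # emits the "==== AOA Traceback (related chain) ====" block

In normal use this bookkeeping is driven by Lexer and AOAEngine when the aoa_config sets enable_traceback=True (the default), as exercised by test_AOATraceback.py above.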