-
Notifications
You must be signed in to change notification settings - Fork 5.9k
optimize_wrong_information #76813
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
optimize_wrong_information #76813
Changes from all commits
150be1f
cc7ae86
cc82463
c031bab
2bac020
7ed4362
913b00a
73df17d
f6a8bb9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -74,7 +74,9 @@ def extract_axis_and_clean_tokens(tokens): | |
| axis = int(tokens[idx + 2].value) | ||
| end_idx = idx + 3 | ||
| if end_idx < len(tokens) - 1: | ||
| assert tokens[end_idx].value == "," | ||
| assert tokens[end_idx].value == ",", ( | ||
| f"The different attributes must split by a comma, but now the token is {tokens[end_idx].value}." | ||
| ) | ||
| end_idx += 1 | ||
| tokens = tokens[:idx] + tokens[end_idx:] | ||
| break | ||
|
|
@@ -141,7 +143,9 @@ def layer_id_offset_macro(tokens, expression, context): | |
| ), | ||
| None, | ||
| ) | ||
| assert name_with_layer_id_offset, "No $LAYER_ID_OFFSET found in NAME tokens" | ||
| assert name_with_layer_id_offset, ( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. LAYER_ID,EXPERT_ID等匹配的Macro,如果匹配到了空的结果,往往是不符合预期的,这时希望给一个警告,同时把最原始的AOA statement打印出来,ID什么也没匹配到,可能不符合原始预期。 |
||
| "No $LAYER_ID_OFFSET found in NAME tokens.Please check the aoa_config." | ||
| ) | ||
| assert all( | ||
| (t.type != TokenType.IDENTIFIER) | ||
| or (LAYER_ID_OFFSET_MACRO_TAG in t.value) | ||
|
|
@@ -197,6 +201,8 @@ def array_macro(tokens, expression, context): | |
| and tokens[idx + 2].type == TokenType.COLON | ||
| and tokens[idx + 3].type == TokenType.NUMBER | ||
| and tokens[idx + 4].type == TokenType.RBRACKET | ||
| ), ( | ||
| f"The array macro format is incorrect which is must be like: NAME[START:END], but now the format is {tokens[idx].value}{tokens[idx + 1].value}:{tokens[idx + 3].value}{tokens[idx + 4].value}." | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 |
||
| ) | ||
| new_tokens.pop() | ||
| start = int(tokens[idx + 1].value) | ||
|
|
@@ -249,16 +255,18 @@ def fused_qkv_old_macro(tokens, expression, context): | |
| ): | ||
| right_var_end_pos = idx + 1 | ||
|
|
||
| assert attn_head_num and attn_head_num > 0, "num_heads must be positive." | ||
| assert attn_head_num and attn_head_num > 0, ( | ||
| f"num_heads must be positive.(got: {attn_head_num})." | ||
| ) | ||
| assert num_key_value_groups and num_key_value_groups > 0, ( | ||
| "num_key_value_groups must be positive." | ||
| f"num_key_value_groups must be positive.(got: {num_key_value_groups})." | ||
| ) | ||
| assert fused_qkv_old_pos is not None, ( | ||
| "No fused_qkv_old tag found in expression." | ||
| f"No fused_qkv_old tag found in expression. The tag must be {FUSED_QKV_OLD_TAG}." | ||
| ) | ||
| assert rarrow_pos is not None, "No -> found in expression." | ||
| assert attn_head_num % num_key_value_groups == 0, ( | ||
| "num_heads must be divisible by num_key_value_groups." | ||
| f"num_heads ({attn_head_num}) must be divisible by num_key_value_groups ({num_key_value_groups})." | ||
| ) | ||
|
|
||
| results = [] | ||
|
|
@@ -413,7 +421,9 @@ def fused_ffn_macro(tokens, expression, context): | |
| ): | ||
| fused_ffn_pos = idx | ||
| assert rarrow_pos is not None, "No -> found in expression." | ||
| assert fused_ffn_pos is not None, "No fused_ffn tag found in expression." | ||
| assert fused_ffn_pos is not None, ( | ||
| f"No fused_ffn tag found in expression. The tag must be {FUSED_FFN_TAG}." | ||
| ) | ||
| results = [] | ||
| if rarrow_pos == 1: | ||
| src_ffn_weight_name = tokens[0].value | ||
|
|
@@ -607,7 +617,9 @@ def fused_qkv_macro(tokens, expression, context): | |
| assert num_key_value_groups and num_key_value_groups > 0, ( | ||
| f"num_key_value_groups must be positive (got: {num_key_value_groups})" | ||
| ) | ||
| assert fused_qkv_pos is not None, "No fused_qkv tag found in expression." | ||
| assert fused_qkv_pos is not None, ( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个也希望把aoa stmt直接打出来 |
||
| f"No fused_qkv tag found in expression. The tag must be {FUSED_QKV_TAG}." | ||
| ) | ||
| assert rarrow_pos is not None, "No -> found in expression." | ||
| assert rarrow_pos == 1 or rarrow_pos == 5, ( | ||
| "Only support q,k,v -> fused_qkv or fused_qkv -> q,k,v patterns" | ||
|
|
@@ -783,7 +795,7 @@ def id(tokens, expression, context): | |
|
|
||
| for token in IDENTIFIER_tokens: | ||
| assert all(k in token.value for k in valid_keys), ( | ||
| f"The token: {token.value} must contain all of the following keys: {valid_keys}" | ||
| f"The token: {token.value} must contain all of the following keys: {valid_keys}.When use the id macro all IDENTIFIER tokens must contain the same ID placeholders." | ||
| ) | ||
|
|
||
| def dict_cartesian_tuples(d: dict[str, list[int]]): | ||
|
|
@@ -837,7 +849,7 @@ def get_var_mapping_chain_macro(tokens, expression, context): | |
| else: | ||
| right_var_list.append(extra_suffix_removed_value) | ||
| assert len(left_var_list) == 1 or len(right_var_list) == 1, ( | ||
| "Left or right variable must have the only one element" | ||
| "Left or right variable must have the only one element,the aoa_statements not support 'multiple var -> multiple var' pattern." | ||
| ) | ||
| if len(left_var_list) == 1: | ||
| context.left_var_to_right_var_mapping[left_var_list[0]] = right_var_list | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,133 @@ | ||
| # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
|
|
||
| class AOATraceback: | ||
| """ | ||
| When error occurs, print the chain of "original aoa_statement -> ... -> current expression". | ||
| """ | ||
|
|
||
| def __init__(self) -> None: | ||
| self.records: list[dict] = [] | ||
| self.last_error_chain: list[str] = [] | ||
| self.last_error_message: str = "" | ||
| self.last_error_stage: str = "" | ||
| self.last_error_type: str = "" | ||
| self.parent_map: dict[str, str | None] = {} | ||
| self.child_macro_map: dict[str, str] = {} | ||
|
|
||
| def register_roots(self, expressions: list[str]) -> None: | ||
| """Register the original aoa_statements as the root nodes of the chain.""" | ||
| for expr in expressions: | ||
| self.parent_map.setdefault(expr, None) | ||
|
|
||
| def record_children( | ||
| self, parent: str, children: list[str], macro_name: str | None = None | ||
| ) -> None: | ||
| """Record the children expressions obtained by the parent expression, and mark the macro name used.""" | ||
| macro = macro_name or "Expanded" | ||
| for child in children: | ||
| if child == parent: | ||
| continue | ||
| self.parent_map[child] = parent | ||
| self.child_macro_map[child] = macro | ||
|
|
||
| def build_chain(self, expr: str) -> list[str]: | ||
| """Build the chain from the root to expr by tracing back from the current expression.""" | ||
| chain: list[str] = [] | ||
| visited = set() | ||
| cur = expr | ||
| while cur is not None and cur not in visited: | ||
| chain.append(cur) | ||
| visited.add(cur) | ||
| cur = self.parent_map.get(cur) | ||
| chain.reverse() | ||
| return chain | ||
|
|
||
| def add_error( | ||
| self, | ||
| error_message: str, | ||
| stage: str, | ||
| chain: list[str], | ||
| error_type: str = "", | ||
| ) -> None: | ||
| """Record the error chain and information.""" | ||
| self.last_error_chain = chain | ||
| self.last_error_message = error_message | ||
| self.last_error_stage = stage | ||
| self.last_error_type = error_type or "" | ||
| self.records.append( | ||
| { | ||
| "type": "error", | ||
| "stage": stage, | ||
| "message": error_message, | ||
| "error_type": self.last_error_type, | ||
| "chain": chain, | ||
| } | ||
| ) | ||
|
|
||
| def format_traceback(self) -> str: | ||
| lines: list[str] = [] | ||
| header_text = " AOA Traceback (related chain) " | ||
| header = f"===={header_text}====" | ||
| footer = "=" * len(header) | ||
|
|
||
| if self.last_error_chain: | ||
| lines.append(header) | ||
| indent_unit = " " | ||
|
|
||
| lines.append("| Origin AOA Statement") | ||
| origin_expr = self.last_error_chain[0].replace("\n", " ") | ||
| lines.append(f"|-> {origin_expr}") | ||
|
|
||
| for level, expr in enumerate(self.last_error_chain[1:], start=1): | ||
| indent = indent_unit * level | ||
| single_line_expr = expr.replace("\n", " ") | ||
| macro = self.child_macro_map.get( | ||
| expr, self.last_error_stage or "Expanded" | ||
| ) | ||
| lines.append(f"{indent}| {macro}") | ||
| lines.append(f"{indent}|-> {single_line_expr}") | ||
|
|
||
| if self.last_error_message: | ||
| err_title = self.last_error_type or "Error" | ||
| stage_str = ( | ||
| f" [{self.last_error_stage}]" | ||
| if self.last_error_stage | ||
| else "" | ||
| ) | ||
| err_level = len(self.last_error_chain) | ||
| indent = indent_unit * err_level | ||
| single_line_msg = self.last_error_message.replace("\n", " ") | ||
| lines.append(f"{indent}| Error") | ||
| lines.append( | ||
| f"{indent}|-> ({err_title}{stage_str}) {single_line_msg}" | ||
| ) | ||
|
|
||
| lines.append(footer) | ||
| else: | ||
| lines.append(header) | ||
| lines.append("(No trace records)") | ||
| lines.append(footer) | ||
|
|
||
| return "\n".join(lines) | ||
|
|
||
| def print(self, logger=None) -> None: | ||
| text = self.format_traceback() | ||
| if logger: | ||
| logger.error(text) | ||
| else: | ||
| print(text) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -223,6 +223,12 @@ def get_rank_to_files( | |
| f"Missing keys:{missing_keys}, check whether the checkpoint is complete." | ||
| ) | ||
|
|
||
| unexpected_keys = set(tensor_key_list) - set(state_dict_param_names) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 建议 每个key单独一行 前后加一个醒目的界符 |
||
| if len(unexpected_keys) > 0: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 |
||
| logger.warning( | ||
| f"Unexpected keys:{unexpected_keys}, these keys exist in checkpoint but not in state_dict." | ||
| ) | ||
|
|
||
| rank_to_files = {} | ||
| for rank, need_files in enumerate(all_necessary_files): | ||
| seen = set() | ||
|
|
@@ -1073,7 +1079,7 @@ def load_state_dict_impl( | |
| with paddle.base.dygraph.guard(): | ||
| global _metadata_manager | ||
| assert isinstance(state_dict, dict), ( | ||
| "The state_dict should be a dictionary." | ||
| f"The state_dict should be a dictionary.But now the type is {type(state_dict)}." | ||
| ) | ||
| first_key = next(iter(state_dict), None) | ||
| if isinstance(first_key, tuple): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个可不可以直接把aoa stmt打印出来