Skip to content
Prev Previous commit
Next Next commit
add AOATraceback in shape_propagation
  • Loading branch information
zty-king committed Dec 16, 2025
commit 73df17d30a9fda884eca08ee333bae0934f36b77
253 changes: 145 additions & 108 deletions python/paddle/distributed/flex_checkpoint/aoa/aoa_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,15 @@ def __init__(
)
self.parser = Parser(tokens)
self.statements = self.parser.parse_program()

if self.traceback and getattr(self.lexer, "final_expressions", None):
final_exprs = self.lexer.final_expressions
if len(final_exprs) == len(self.statements):
for expr, stmt in zip(final_exprs, self.statements):
self.traceback.record_children(
expr, [repr(stmt)], macro_name="parser"
)

if self.aoa_config_reverse:
self.statements = list(reversed(self.statements))
self.input_vars = self.build_input_vars()
Expand Down Expand Up @@ -506,125 +515,153 @@ def _get_var_ref(var):
raise ValueError(f"{var.name} should be assigned before!")

for stmt in self.statements:
stmt_repr = repr(stmt)
left_vars = stmt.left_vars
right_vars = stmt.right_vars
if self.aoa_config_reverse:
left_vars, right_vars = right_vars, left_vars
attrs = stmt.attrs
if len(left_vars) > 1 or len(right_vars) > 1:
if not (len(attrs) == 1 and attrs[0].key == "axis"):
raise ValueError(
"When split/concat, only support one attr named `axis`"
)
axis = attrs[0].value

if len(left_vars) == 1:
in_name = left_vars[0].name
in_ref = _get_var_ref(left_vars[0])
assert in_ref.shape[axis] % len(right_vars) == 0, (
f"when split, the shape of the input tensor {in_name} is {in_ref.shape}, the axis is {axis}, the number of right_vars is {len(right_vars)}, but the shape of the input tensor {in_name} is not divisible by the number of right_vars."
)
sizes = [
in_ref.shape[axis] // len(right_vars)
for var in right_vars
]
result = self.split(in_ref, axis, sizes)
for out_var, out_ref in zip(right_vars, result):
self.intermediate_vars[out_var.name] = out_ref
if (
out_var.name
in self.context.get_all_dst_state_keys()
):
self.output_vars[out_var.name] = out_ref

elif len(right_vars) == 1:
left_refs = [_get_var_ref(var) for var in left_vars]
result = self.concat(left_refs, axis)
out_name = right_vars[0].name
self.intermediate_vars[out_name] = result
if out_name in self.context.get_all_dst_state_keys():
self.output_vars[out_name] = result
try:
if len(left_vars) > 1 or len(right_vars) > 1:
if not (len(attrs) == 1 and attrs[0].key == "axis"):
raise ValueError(
f"When split/concat, only support one attr named `axis`, but got {attrs}."
)
axis = attrs[0].value

else:
raise SyntaxError(
f'Unexpected split/concat statement: {stmt}'
)
if len(left_vars) == 1:
in_name = left_vars[0].name
in_ref = _get_var_ref(left_vars[0])
assert in_ref.shape[axis] % len(right_vars) == 0, (
f"when split, the shape of the input tensor {in_name} is {in_ref.shape}, the axis is {axis}, the number of right_vars is {len(right_vars)}, but the shape of the input tensor {in_name} is not divisible by the number of right_vars."
)
sizes = [
in_ref.shape[axis] // len(right_vars)
for var in right_vars
]
result = self.split(in_ref, axis, sizes)
for out_var, out_ref in zip(right_vars, result):
self.intermediate_vars[out_var.name] = out_ref
if (
out_var.name
in self.context.get_all_dst_state_keys()
):
self.output_vars[out_var.name] = out_ref

elif len(right_vars) == 1:
left_refs = [_get_var_ref(var) for var in left_vars]
result = self.concat(left_refs, axis)
out_name = right_vars[0].name
self.intermediate_vars[out_name] = result
if out_name in self.context.get_all_dst_state_keys():
self.output_vars[out_name] = result

elif len(left_vars) == 1 and len(right_vars) == 1:
lvar, rvar = left_vars[0], right_vars[0]
if rvar.name == "_":
self.need_remove_input_vars.add(lvar.name)
elif lvar.name == "_":
self.need_add_output_vars.add(rvar.name)
else:
if len(attrs) > 0:
assert len(attrs) == 1 or (
len(attrs) == 2
and {attr.key for attr in attrs}
== {"src_dtype", "dst_dtype"}
), (
"Only support:\n"
" - One operator, OR\n"
" - Two operators with keys {'src_dtype', 'dst_dtype'}."
else:
raise SyntaxError(
f'Unexpected split/concat statement: {stmt}'
)
attr = attrs[0]
in_ref = _get_var_ref(lvar)
if attr.key == "permute":
if attr.value == "[]":
ndim = len(in_ref.shape)
perm = str(list(range(ndim - 1, -1, -1)))
else:
perm = attr.value
if self.aoa_config_reverse:
perm = str(
invert_permutation(
ast.literal_eval(perm)
)
)
result = self.transpose(in_ref, perm)
elif attr.key == "dtype":
assert not self.aoa_config_reverse, (
"When `aoa_config_reverse=True`, the dtype must be specified as "
"'src_dtype=...,dst_dtype=...'. Formats like 'dtype=xxx' are not supported."
)
assert attr.value in SUPPORTED_DTYPES, (
f"Unsupported cast dtype: {attr.value}"
)
result = self.cast(in_ref, attr.value)
elif (
attrs[0].key == "src_dtype"
and attrs[1].key == "dst_dtype"
):
src_dtype, dst_dtype = (
attrs[0].value,
attrs[1].value,
)
assert src_dtype in SUPPORTED_DTYPES, (
f"Unsupported cast dtype: {src_dtype}"
)
assert dst_dtype in SUPPORTED_DTYPES, (
f"Unsupported cast dtype: {dst_dtype}"
)
if self.aoa_config_reverse:
src_dtype, dst_dtype = dst_dtype, src_dtype
result = self.cast(in_ref, dst_dtype)
elif attr.key == "axis":
result = in_ref
else:
raise ValueError(f"Unsupported attribute: {attr}")

self.intermediate_vars[rvar.name] = result
if rvar.name in self.context.get_all_dst_state_keys():
self.output_vars[rvar.name] = result
elif len(left_vars) == 1 and len(right_vars) == 1:
lvar, rvar = left_vars[0], right_vars[0]
if rvar.name == "_":
self.need_remove_input_vars.add(lvar.name)
elif lvar.name == "_":
self.need_add_output_vars.add(rvar.name)
else:
# rename operation
in_ref = _get_var_ref(lvar)
result = self.identity(in_ref)
self.intermediate_vars[rvar.name] = result
if rvar.name in self.context.get_all_dst_state_keys():
self.output_vars[rvar.name] = result
else:
raise SyntaxError(f'Unexpected statement: {stmt}')
if len(attrs) > 0:
assert len(attrs) == 1 or (
len(attrs) == 2
and {attr.key for attr in attrs}
== {"src_dtype", "dst_dtype"}
), (
"Only support:\n"
" - One operator, OR\n"
" - Two operators with keys {'src_dtype', 'dst_dtype'}."
)
attr = attrs[0]
in_ref = _get_var_ref(lvar)
if attr.key == "permute":
if attr.value == "[]":
ndim = len(in_ref.shape)
perm = str(list(range(ndim - 1, -1, -1)))
else:
perm = attr.value
if self.aoa_config_reverse:
perm = str(
invert_permutation(
ast.literal_eval(perm)
)
)
result = self.transpose(in_ref, perm)
elif attr.key == "dtype":
assert not self.aoa_config_reverse, (
"When `aoa_config_reverse=True`, the dtype must be specified as "
"'src_dtype=...,dst_dtype=...'. Formats like 'dtype=xxx' are not supported."
)
assert attr.value in SUPPORTED_DTYPES, (
f"Unsupported cast dtype: {attr.value}"
)
result = self.cast(in_ref, attr.value)
elif (
attrs[0].key == "src_dtype"
and attrs[1].key == "dst_dtype"
):
src_dtype, dst_dtype = (
attrs[0].value,
attrs[1].value,
)
assert src_dtype in SUPPORTED_DTYPES, (
f"Unsupported cast dtype: {src_dtype}"
)
assert dst_dtype in SUPPORTED_DTYPES, (
f"Unsupported cast dtype: {dst_dtype}"
)
if self.aoa_config_reverse:
src_dtype, dst_dtype = dst_dtype, src_dtype
result = self.cast(in_ref, dst_dtype)
elif attr.key == "axis":
result = in_ref
else:
raise ValueError(
f"Unsupported attribute: {attr}"
)

self.intermediate_vars[rvar.name] = result
if (
rvar.name
in self.context.get_all_dst_state_keys()
):
self.output_vars[rvar.name] = result
else:
# rename operation
in_ref = _get_var_ref(lvar)
result = self.identity(in_ref)
self.intermediate_vars[rvar.name] = result
if (
rvar.name
in self.context.get_all_dst_state_keys()
):
self.output_vars[rvar.name] = result
else:
raise SyntaxError(f'Unexpected statement: {stmt}')
except (
AssertionError,
ValueError,
KeyError,
SyntaxError,
RuntimeError,
) as e:
if self.traceback:
chain = self.traceback.build_chain(stmt_repr)
self.traceback.add_error(
error_message=str(e),
stage="shape_propagation",
chain=chain,
error_type=type(e).__name__,
)
self.traceback.print()
raise
if self.destination_state_shard_info is not None:
for name in self.destination_state_shard_info:
model_state_key, _ = split_optimizer_state_key(name)
Expand Down
1 change: 1 addition & 0 deletions python/paddle/distributed/flex_checkpoint/aoa/lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def all_tokens(self, expressions):
current_expressions, macro
)

self.final_expressions = list(current_expressions)
tokens = []
for expr in current_expressions:
tokens.extend(self.tokenize(expr))
Expand Down
40 changes: 31 additions & 9 deletions test/flex_checkpoint/test_AOATraceback.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,16 @@ def create_shard_info(keys, shape=(4, 4), dtype="float32"):


class TestMacroLayerOffsetError(unittest.TestCase):
def test_fused_qkv_old_error_chain(self):
def setUp(self):
    """Build shared source/destination shard-info fixtures for the tests."""
    layer_ids = range(10)
    self.source_keys = [f"model.layers.{i}.weight" for i in layer_ids]
    # Two destination keys per source layer: weight_out and weight_out2.
    primary_outs = [f"model.layers.{i}.weight_out" for i in layer_ids]
    secondary_outs = [f"model.layers.{i}.weight_out2" for i in layer_ids]
    self.dest_keys = primary_outs + secondary_outs

    self.src_info = create_shard_info(self.source_keys)
    self.dst_info = create_shard_info(self.dest_keys)

def test_macro_error_chain(self):
"""
The statement contains fused_qkv_old and is missing a comma, expecting to trigger the assertion and print the chain.
"""
Expand All @@ -46,19 +55,14 @@ def test_fused_qkv_old_error_chain(self):
"enable_traceback": True,
}

source_keys = [f"model.layers.{i}.weight" for i in range(10)]
dest_keys = [f"model.layers.{i}.weight_out" for i in range(10)]

src_info = create_shard_info(source_keys)
dst_info = create_shard_info(dest_keys)

with self.assertRaises(AssertionError):
AOAEngine(
aoa_config=aoa_config,
source_state_shard_info=src_info,
destination_state_shard_info=dst_info,
source_state_shard_info=self.src_info,
destination_state_shard_info=self.dst_info,
)

def test_no_error_should_be_raised(self):
# No error should be raised
source_keys = ["model.layers.0.weight"]
dest_keys = ["model.layers.0.weight_out"]
Expand All @@ -76,6 +80,24 @@ def test_fused_qkv_old_error_chain(self):
destination_state_shard_info=dst_info,
)

def test_shape_propagation_error_chain(self):
    """
    A split statement carrying more than one attribute (two `axis`
    entries here) is invalid, so engine construction must raise
    ``ValueError`` during shape propagation.
    """
    config = {
        "aoa_statements": [
            "model.layers.0.weight -> model.layers.0.weight_out,model.layers.0.weight_out2,axis=0,axis=1",
        ],
        "enable_traceback": True,
    }

    with self.assertRaises(ValueError):
        AOAEngine(
            source_state_shard_info=self.src_info,
            destination_state_shard_info=self.dst_info,
            aoa_config=config,
        )


if __name__ == "__main__":
unittest.main()
Loading