Merged
autoparallel/asynctp.py: 48 additions & 40 deletions
@@ -33,13 +33,15 @@
aten = torch.ops.aten
patterns = PatternMatcherPass()

_micro_pipeline_tp_ag_transpose_mm_enabled = True
# Configs:
_ag_transpose_mm_enabled = False
_ag_mm_last_dim_enabled = True
_ag_mm_last_dim_no_splitcat_use = False
_mm_rs_last_dim_enabled = True

# Check performance if overhead of decomposition outweighs pipeline wins
_micro_pipeline_tp_ag_mm_last_dim_enabled = True
_micro_pipeline_tp_ag_mm_last_dim_splitcatuse_enabled = True

_micro_pipeline_tp_mm_rs_last_dim_enabled = True
def _is_last_dim(t: torch.Tensor, dim: int) -> bool:
return dim == t.ndim - 1 or dim == -1
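
A quick note on the new helper: `_is_last_dim` accepts either the explicit index of the final dimension or `-1`, which is why the callers below can pass `gather_dim` / `scatter_dim` in either form. A minimal sketch (the tensor shape is made up for illustration):

```python
import torch

def _is_last_dim(t: torch.Tensor, dim: int) -> bool:
    return dim == t.ndim - 1 or dim == -1

t = torch.empty(8, 16, 32)      # ndim == 3
assert _is_last_dim(t, 2)       # explicit index of the last dim
assert _is_last_dim(t, -1)      # negative indexing also counts
assert not _is_last_dim(t, 1)   # any other dim is not "last"
```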


def _is_backward(graph: torch.fx.Graph) -> bool:
@@ -617,7 +619,7 @@ def _find_consumer_matmuls(node: torch.fx.Node) -> list[_Matmul]:
matmul = _ScaledMatmul.from_match([user])
matmuls.append(matmul)
elif (
_micro_pipeline_tp_ag_transpose_mm_enabled
_ag_transpose_mm_enabled
and user.target == aten.permute.default
and (user.args[1] == [1, 0] or user.args[1] == [0, 1])
):
@@ -762,11 +764,12 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch, log_strs) -> None:
if not is_symm_mem_enabled_for_group(group_name):
return

if (
not _micro_pipeline_tp_ag_mm_last_dim_enabled
and gather_dim == _get_tensor(shard_node).ndim - 1
):
return
if _is_last_dim(_get_tensor(shard_node), gather_dim):
if not _ag_mm_last_dim_enabled:
return

if _get_tensor(shard_node).shape[-1] < 1024:
return
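
The shape gate above is a size heuristic: when the gather lands on the last (K) dimension, the fusion is only attempted if each rank's K slice is at least 1024 wide, presumably because the overhead of decomposition can otherwise outweigh the pipelining win, as the earlier config comment notes. A tiny sketch with hypothetical shard shapes:

```python
import torch

# Hypothetical shards; the gate looks at the per-rank K width (shape[-1]).
wide_shard = torch.empty(64, 2048)
narrow_shard = torch.empty(64, 512)

assert wide_shard.shape[-1] >= 1024      # last-dim fusion is attempted
assert narrow_shard.shape[-1] < 1024     # gate returns early, no fusion
```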

# Find consumer matmuls
matmuls = _find_consumer_matmuls(ag_res_node)
@@ -784,13 +787,12 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch, log_strs) -> None:
return

if (
_micro_pipeline_tp_ag_mm_last_dim_splitcatuse_enabled
and gather_dim == _get_tensor(shard_node).ndim - 1
_ag_mm_last_dim_no_splitcat_use
and _is_last_dim(_get_tensor(shard_node), gather_dim)
and len(all_gather.res_node.users) > len(matmuls)
):
# The result of ag-split-cat is used not only by matmuls,
# so it has to be materialized, which can add overhead.
# TODO: find out the stride conditions under which there is no overhead.
log_strs.append(
f"fuse_agmm lastdim ag-split-cat {len(all_gather.res_node.users)} used more than num matmuls"
)
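
For context on the materialization mentioned in the comment above: with a last-dim gather, the fused op returns the gathered A only by reassembling per-rank flat shards (see `A_flat_out` and the chunk/cat in `_fused_all_gather_matmul_last_gather_dim_impl` further down), so any non-matmul consumer forces an extra copy. A rough sketch of that reassembly, assuming the same stacked-along-dim-0 layout (sizes are made up):

```python
import torch

# Hypothetical sizes, for illustration only.
group_size, rows, k_shard = 4, 8, 16

# What the fused comm kernel hands back: per-rank flat shards stacked along dim 0.
a_flat_out = torch.randn(rows * group_size, k_shard)

# A non-matmul consumer of the all-gather result needs the shards laid out along
# the last dim, which forces this chunk + cat copy; that copy is the overhead
# the pass is avoiding by bailing out here.
a_gathered = torch.cat(a_flat_out.chunk(group_size, dim=0), dim=-1)
assert a_gathered.shape == (rows, k_shard * group_size)
```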
@@ -837,15 +839,16 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch, log_strs) -> None:
matmul.replace_with(new_out_node)
matmul.erase()
else:
if "val" in shard_node.meta:
restrided = restride_A_shard_for_fused_all_gather_matmul(
_get_tensor(shard_node),
gather_dim,
)
shard_node = graph.call_function(
inductor_prims.force_stride_order,
args=(shard_node, restrided.stride()),
)
if not _is_last_dim(_get_tensor(shard_node), gather_dim):
if "val" in shard_node.meta:
restrided = restride_A_shard_for_fused_all_gather_matmul(
_get_tensor(shard_node),
gather_dim,
)
shard_node = graph.call_function(
inductor_prims.force_stride_order,
args=(shard_node, restrided.stride()),
)
fused_node = _insert_fused_all_gather_matmul(
graph, matmuls, shard_node, gather_dim, group_name
)
@@ -1055,11 +1058,14 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch, log_strs) -> None:
log_strs.append("fuse_mmrs not symm mem group")
return

if (
not _micro_pipeline_tp_mm_rs_last_dim_enabled
and orig_scatter_dim == _get_tensor(input_node).ndim - 1
):
return
if _is_last_dim(_get_tensor(input_node), orig_scatter_dim):
if not _mm_rs_last_dim_enabled:
return

group = torch._C._distributed_c10d._resolve_process_group(group_name)
group_size = group.size()
if _get_tensor(input_node).shape[-1] // group_size < 1024:
return
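
Note that this gate measures the width of each rank's output shard rather than the full last dimension: with a last-dim scatter, every rank receives shape[-1] // group_size columns, and the fusion is skipped when that slice is narrower than the same 1024 cutoff used on the all-gather side. The arithmetic, with hypothetical numbers:

```python
# Hypothetical sizes, just to show the arithmetic behind this gate.
last_dim = 4096                          # columns of the tensor being reduce-scattered
group_size = 8                           # ranks in the TP group

per_rank_cols = last_dim // group_size   # each rank ends up with 512 columns
assert per_rank_cols < 1024              # too narrow, so the fusion is skipped
```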

# Currently fused_matmul_reduce_scatter doesn't return the matmul result,
# so we can't apply the fusion if the matmul result is used by multiple
@@ -1113,16 +1119,17 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch, log_strs) -> None:

graph = rs_wait_tensor_node.graph
with graph.inserting_before(rs_wait_tensor_node):
# Restride A tensor before fused op, for optimal perf in fused matmul reduce scatter
if "val" in matmul.A_node.meta:
restrided = restride_A_for_fused_matmul_reduce_scatter(
_get_tensor(matmul.A_node),
scatter_dim_after_maybe_reshape,
)
matmul.A_node = graph.call_function(
inductor_prims.force_stride_order,
args=(matmul.A_node, restrided.stride()),
)
if not _is_last_dim(_get_tensor(input_node), orig_scatter_dim):
# Restride A tensor before fused op, for optimal perf in fused matmul reduce scatter
if "val" in matmul.A_node.meta:
restrided = restride_A_for_fused_matmul_reduce_scatter(
_get_tensor(matmul.A_node),
scatter_dim_after_maybe_reshape,
)
matmul.A_node = graph.call_function(
inductor_prims.force_stride_order,
args=(matmul.A_node, restrided.stride()),
)

# Replace matched subgraph with fused matmul reduce scatter node
fused_node = _insert_fused_matmul_reduce_scatter(
Expand Down Expand Up @@ -1295,11 +1302,12 @@ def micro_pipeline_tp_pass(
"async TP found no matching all-gather/reduce-scatter patterns for fusion"
)

for all_gather in all_gathers:
fuse_all_gather_matmul(all_gather, log_strs)

for reduce_scatter in reduce_scatters:
fuse_matmul_reduce_scatter(reduce_scatter, log_strs)

for all_gather in all_gathers:
fuse_all_gather_matmul(all_gather, log_strs)
trace_structured(
"artifact",
metadata_fn=lambda: {
autoparallel/asynctp_ops.py: 8 additions & 18 deletions
@@ -418,9 +418,6 @@ def _fused_all_gather_matmul_impl(
group = c10d._resolve_process_group(group_name)

if gather_dim == A_shard.ndim - 1:
# Implementation for gathering on last dimension of matmul (N)
# A_shard splitted column wise
# A_shard: [A0, A1, ... , Ags]
return _fused_all_gather_matmul_last_gather_dim_impl(
mm_out_op,
A_shard,
@@ -625,11 +622,6 @@ def _fused_all_gather_matmul_last_gather_dim_impl(
def unflatten(t: torch.Tensor) -> torch.Tensor:
return t.view(*leading_dims, -1)

A_out_leading_dims = list(A_shard.shape[:-1])

def unflatten_A_out(t: torch.Tensor) -> torch.Tensor:
return t.view(*A_out_leading_dims, -1)

A_flat_out = A_shard_flat.new_empty(
A_shard_flat.shape[0] * group.size(),
A_shard_flat.shape[1],
@@ -645,19 +637,17 @@ def unflatten_A_out(t: torch.Tensor) -> torch.Tensor:
for B, out_dtype in zip(Bs, out_dtypes)
]

# Additional allocation for partials output,
# That will be reduced into output.
output_partials = [torch.empty_like(out) for out in outputs]

first = True

def default_consumer(shard: torch.Tensor, rank: int) -> None:
nonlocal first
for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)):
out = outputs[idx] if first else output_partials[idx]
mm_out_op(shard, B_shards[idx][rank], **kwargs, out=out)
if not first:
outputs[idx] += output_partials[idx]
out = outputs[idx]
if first:
torch.ops.aten.mm.out(shard, B_shards[idx][rank], **kwargs, out=out)
Contributor: Should we prefer using the torch.mm version instead of the torch.ops.aten.mm version? I'm not sure there is effectively a difference, but maybe for consistency?

Contributor (author): Yeah, I think there should not be much difference, we can use torch.mm.

else:
out.addmm_(shard, B_shards[idx][rank])

first = False
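
The mm/addmm_ pattern in `default_consumer` is blockwise matmul accumulation: with A gathered along its last (K) dimension, A @ B equals the sum over ranks of A_r @ B_r, where B_r is the matching row block of B, so the first shard can write the output buffer directly and later shards accumulate in place. A self-contained sketch of the identity (shapes are made up; it uses torch.mm, as suggested in the review thread above):

```python
import torch

group_size, m, k_shard, n = 4, 8, 16, 32
A_shards = [torch.randn(m, k_shard) for _ in range(group_size)]  # per-rank K slices of A
B = torch.randn(k_shard * group_size, n)
B_shards = B.chunk(group_size, dim=0)                            # matching row blocks of B

out = torch.empty(m, n)
for rank, shard in enumerate(A_shards):
    if rank == 0:
        torch.mm(shard, B_shards[rank], out=out)   # first shard initializes the buffer
    else:
        out.addmm_(shard, B_shards[rank])          # later shards accumulate in place

A_full = torch.cat(A_shards, dim=-1)               # what the all-gather would produce
torch.testing.assert_close(out, A_full @ B)
```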

_pipelined_all_gather_and_consume_last_dim(
@@ -672,7 +662,7 @@ def default_consumer(shard: torch.Tensor, rank: int) -> None:
# This path is inefficient and will be filtered out at the passes stage.
# Added only for completeness.
A_split_cat_out_flat = torch.cat(A_flat_out.chunk(group_size), dim=-1)
ret_A = unflatten_A_out(A_split_cat_out_flat)
ret_A = unflatten(A_split_cat_out_flat)

return ret_A, [unflatten(output) for output in outputs]

@@ -1134,7 +1124,7 @@ def _fused_matmul_reduce_scatter_impl(
out_shape = [*A.shape[:-1], B.shape[1]]
out_shape[scatter_dim] //= group.size()

if scatter_dim == A.ndim - 1:
if scatter_dim == A.ndim - 1 or scatter_dim == -1:
B_shards = B.chunk(group.size(), dim=B.ndim - 1)
A_flat = A.flatten(0, -2)
