diff --git a/autoparallel/asynctp.py b/autoparallel/asynctp.py
index 7cec3af..712634b 100644
--- a/autoparallel/asynctp.py
+++ b/autoparallel/asynctp.py
@@ -33,13 +33,15 @@
 aten = torch.ops.aten
 
 patterns = PatternMatcherPass()
 
-_micro_pipeline_tp_ag_transpose_mm_enabled = True
+# Configs:
+_ag_transpose_mm_enabled = False
+_ag_mm_last_dim_enabled = True
+_ag_mm_last_dim_no_splitcat_use = False
+_mm_rs_last_dim_enabled = True
 
-# Check performance if overhead of decomposition outweights pipeline wins
-_micro_pipeline_tp_ag_mm_last_dim_enabled = True
-_micro_pipeline_tp_ag_mm_last_dim_splitcatuse_enabled = True
-_micro_pipeline_tp_mm_rs_last_dim_enabled = True
+def _is_last_dim(t: torch.Tensor, dim: int) -> bool:
+    return dim == t.ndim - 1 or dim == -1
 
 
 def _is_backward(graph: torch.fx.Graph) -> bool:
@@ -617,7 +619,7 @@ def _find_consumer_matmuls(node: torch.fx.Node) -> list[_Matmul]:
             matmul = _ScaledMatmul.from_match([user])
             matmuls.append(matmul)
         elif (
-            _micro_pipeline_tp_ag_transpose_mm_enabled
+            _ag_transpose_mm_enabled
             and user.target == aten.permute.default
             and (user.args[1] == [1, 0] or user.args[1] == [0, 1])
         ):
@@ -762,11 +764,12 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch, log_strs) -> None:
     if not is_symm_mem_enabled_for_group(group_name):
         return
 
-    if (
-        not _micro_pipeline_tp_ag_mm_last_dim_enabled
-        and gather_dim == _get_tensor(shard_node).ndim - 1
-    ):
-        return
+    if _is_last_dim(_get_tensor(shard_node), gather_dim):
+        if not _ag_mm_last_dim_enabled:
+            return
+
+        if _get_tensor(shard_node).shape[-1] < 1024:
+            return
 
     # Find consumer matmuls
     matmuls = _find_consumer_matmuls(ag_res_node)
@@ -784,13 +787,12 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch, log_strs) -> None:
         return
 
     if (
-        _micro_pipeline_tp_ag_mm_last_dim_splitcatuse_enabled
-        and gather_dim == _get_tensor(shard_node).ndim - 1
+        _ag_mm_last_dim_no_splitcat_use
+        and _is_last_dim(_get_tensor(shard_node), gather_dim)
         and len(all_gather.res_node.users) > len(matmuls)
     ):
         # The result of ag-split-cat is used not only in matmuls.
         # Then it has to be materialized, which can have overhead.
-        # TODO: find out conditions of strideness when there is no overhead.
         log_strs.append(
             f"fuse_agmm lastdim ag-split-cat {len(all_gather.res_node.users)} used more than num matmuls"
         )
@@ -837,15 +839,16 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch, log_strs) -> None:
             matmul.replace_with(new_out_node)
            matmul.erase()
     else:
-        if "val" in shard_node.meta:
-            restrided = restride_A_shard_for_fused_all_gather_matmul(
-                _get_tensor(shard_node),
-                gather_dim,
-            )
-            shard_node = graph.call_function(
-                inductor_prims.force_stride_order,
-                args=(shard_node, restrided.stride()),
-            )
+        if not _is_last_dim(_get_tensor(shard_node), gather_dim):
+            if "val" in shard_node.meta:
+                restrided = restride_A_shard_for_fused_all_gather_matmul(
+                    _get_tensor(shard_node),
+                    gather_dim,
+                )
+                shard_node = graph.call_function(
+                    inductor_prims.force_stride_order,
+                    args=(shard_node, restrided.stride()),
+                )
         fused_node = _insert_fused_all_gather_matmul(
             graph, matmuls, shard_node, gather_dim, group_name
         )
@@ -1055,11 +1058,14 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch, log_strs) ->
         log_strs.append("fuse_mmrs not symm mem group")
         return
 
-    if (
-        not _micro_pipeline_tp_mm_rs_last_dim_enabled
-        and orig_scatter_dim == _get_tensor(input_node).ndim - 1
-    ):
-        return
+    if _is_last_dim(_get_tensor(input_node), orig_scatter_dim):
+        if not _mm_rs_last_dim_enabled:
+            return
+
+        group = torch._C._distributed_c10d._resolve_process_group(group_name)
+        group_size = group.size()
+        if _get_tensor(input_node).shape[-1] // group_size < 1024:
+            return
 
     # Currently fused_matmul_reduce_scatter doesn't return the matmul result,
     # so we can't apply the fusion if the matmul result is used by multiple
@@ -1113,16 +1119,17 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch, log_strs) ->
 
     graph = rs_wait_tensor_node.graph
     with graph.inserting_before(rs_wait_tensor_node):
-        # Restride A tensor before fused op, for optimal perf in fused matmul reduce scatter
-        if "val" in matmul.A_node.meta:
-            restrided = restride_A_for_fused_matmul_reduce_scatter(
-                _get_tensor(matmul.A_node),
-                scatter_dim_after_maybe_reshape,
-            )
-            matmul.A_node = graph.call_function(
-                inductor_prims.force_stride_order,
-                args=(matmul.A_node, restrided.stride()),
-            )
+        if not _is_last_dim(_get_tensor(input_node), orig_scatter_dim):
+            # Restride A tensor before fused op, for optimal perf in fused matmul reduce scatter
+            if "val" in matmul.A_node.meta:
+                restrided = restride_A_for_fused_matmul_reduce_scatter(
+                    _get_tensor(matmul.A_node),
+                    scatter_dim_after_maybe_reshape,
+                )
+                matmul.A_node = graph.call_function(
+                    inductor_prims.force_stride_order,
+                    args=(matmul.A_node, restrided.stride()),
+                )
 
         # Replace matched subgraph with fused matmul reduce scatter node
         fused_node = _insert_fused_matmul_reduce_scatter(
@@ -1295,11 +1302,12 @@ def micro_pipeline_tp_pass(
             "async TP found no matching all-gather/reduce-scatter patterns for fusion"
         )
 
+    for all_gather in all_gathers:
+        fuse_all_gather_matmul(all_gather, log_strs)
+
     for reduce_scatter in reduce_scatters:
         fuse_matmul_reduce_scatter(reduce_scatter, log_strs)
 
-    for all_gather in all_gathers:
-        fuse_all_gather_matmul(all_gather, log_strs)
     trace_structured(
         "artifact",
         metadata_fn=lambda: {
diff --git a/autoparallel/asynctp_ops.py b/autoparallel/asynctp_ops.py
index 6d9db73..3fcd924 100644
--- a/autoparallel/asynctp_ops.py
+++ b/autoparallel/asynctp_ops.py
@@ -418,9 +418,6 @@ def _fused_all_gather_matmul_impl(
     group = c10d._resolve_process_group(group_name)
 
     if gather_dim == A_shard.ndim - 1:
-        # Implementation for gathering on last dimension of matmul (N)
-        # A_shard splitted column wise
-        # A_shard: [A0, A1, ... , Ags]
         return _fused_all_gather_matmul_last_gather_dim_impl(
             mm_out_op,
             A_shard,
@@ -625,11 +622,6 @@ def _fused_all_gather_matmul_last_gather_dim_impl(
     def unflatten(t: torch.Tensor) -> torch.Tensor:
         return t.view(*leading_dims, -1)
 
-    A_out_leading_dims = list(A_shard.shape[:-1])
-
-    def unflatten_A_out(t: torch.Tensor) -> torch.Tensor:
-        return t.view(*A_out_leading_dims, -1)
-
     A_flat_out = A_shard_flat.new_empty(
         A_shard_flat.shape[0] * group.size(),
         A_shard_flat.shape[1],
@@ -645,19 +637,17 @@
         for B, out_dtype in zip(Bs, out_dtypes)
     ]
 
-    # Additional allocation for partials output,
-    # That will be reduced into output.
-    output_partials = [torch.empty_like(out) for out in outputs]
-
     first = True
 
     def default_consumer(shard: torch.Tensor, rank: int) -> None:
         nonlocal first
         for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)):
-            out = outputs[idx] if first else output_partials[idx]
-            mm_out_op(shard, B_shards[idx][rank], **kwargs, out=out)
-            if not first:
-                outputs[idx] += output_partials[idx]
+            out = outputs[idx]
+            if first:
+                torch.ops.aten.mm.out(shard, B_shards[idx][rank], **kwargs, out=out)
+            else:
+                out.addmm_(shard, B_shards[idx][rank])
+
         first = False
 
     _pipelined_all_gather_and_consume_last_dim(
@@ -672,7 +662,7 @@ def default_consumer(shard: torch.Tensor, rank: int) -> None:
     # This path is inefficient and will be filtered out at passes stage
     # Added only for completness.
     A_split_cat_out_flat = torch.cat(A_flat_out.chunk(group_size), dim=-1)
-    ret_A = unflatten_A_out(A_split_cat_out_flat)
+    ret_A = unflatten(A_split_cat_out_flat)
 
     return ret_A, [unflatten(output) for output in outputs]
 
@@ -1134,7 +1124,7 @@ def _fused_matmul_reduce_scatter_impl(
     out_shape = [*A.shape[:-1], B.shape[1]]
     out_shape[scatter_dim] //= group.size()
 
-    if scatter_dim == A.ndim - 1:
+    if scatter_dim == A.ndim - 1 or scatter_dim == -1:
         B_shards = B.chunk(group.size(), dim=B.ndim - 1)
         A_flat = A.flatten(0, -2)
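
Note on the `default_consumer` change in `_fused_all_gather_matmul_last_gather_dim_impl`: the separate `output_partials` buffers go away because each partial product is now folded into the output in place, with the first shard written via `mm.out` and later shards accumulated with `addmm_` (out += shard @ B_shard). Below is a minimal standalone sketch, not part of the patch, with made-up sizes and plain `torch.mm` standing in for `mm_out_op`, showing that this accumulation reproduces the full matmul:

import torch

# Hypothetical sizes; A is gathered along its last dim, so each rank's column
# shard of A pairs with the matching row shard of B.
group_size, M, K, N = 4, 8, 32, 16
A = torch.randn(M, K)
B = torch.randn(K, N)
A_shards = A.chunk(group_size, dim=-1)  # per-rank column shards of A
B_shards = B.chunk(group_size, dim=0)   # matching row shards of B

out = torch.empty(M, N)
first = True
for rank in range(group_size):
    if first:
        # First shard: write the partial product directly into the output.
        torch.mm(A_shards[rank], B_shards[rank], out=out)
        first = False
    else:
        # Later shards: accumulate in place, out += A_shard @ B_shard.
        out.addmm_(A_shards[rank], B_shards[rank])

torch.testing.assert_close(out, A @ B)  # sum of partial products == full matmul

Since every partial product is consumed immediately by the in-place add, no scratch allocation is needed, which is why the `output_partials` list could be dropped.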