|
18 | 18 |
|
19 | 19 | #include "tfrt/compiler/stream_analysis.h" |
20 | 20 |
|
| 21 | +#include <optional> |
| 22 | + |
21 | 23 | #include "tfrt/basic_kernels/opdefs/basic_kernels.h" |
22 | 24 | #include "tfrt/basic_kernels/opdefs/types.h" |
23 | 25 | #include "tfrt/compiler/opdefs/tfrt_op_interfaces.h" |
@@ -67,20 +69,49 @@ bool GetMergeInterDependentStreams(mlir::Block& block) { |
67 | 69 | return false; |
68 | 70 | } |
69 | 71 |
|
| 72 | +class DefaultCostModel : public StreamAnalysis::CostModelInterface { |
| 73 | + public: |
| 74 | + std::optional<int64_t> GetOperationCost(mlir::Operation* op) const override { |
| 75 | + // Check if operations defines a cost function. |
| 76 | + if (auto cost_function = mlir::dyn_cast<CostFunctionInterface>(op)) { |
| 77 | + int64_t cost = cost_function.cost(); |
| 78 | + assert(cost > 0 && "cost must be a positive value"); |
| 79 | + return cost; |
| 80 | + } |
| 81 | + |
| 82 | + return std::nullopt; |
| 83 | + } |
| 84 | +}; |
| 85 | + |
| 86 | +const DefaultCostModel* GetDefaultCostModel() { |
| 87 | + static const auto* const default_cost_model = new DefaultCostModel(); |
| 88 | + return default_cost_model; |
| 89 | +} |
| 90 | + |
70 | 91 | } // namespace |
71 | 92 |
|
| 93 | +StreamAnalysis::CostModelInterface::~CostModelInterface() = default; |
| 94 | + |
| 95 | +StreamAnalysis::StreamAnalysis(mlir::func::FuncOp op, |
| 96 | + const CostModelInterface* cost_model) |
| 97 | + : StreamAnalysis(op.front(), cost_model) {} |
| 98 | +StreamAnalysis::StreamAnalysis(mlir::Block& block, |
| 99 | + const CostModelInterface* cost_model) |
| 100 | + : cost_model_(cost_model) { |
| 101 | + if (cost_model_ == nullptr) cost_model_ = GetDefaultCostModel(); |
| 102 | + AnalyzeBlock(block); |
| 103 | +} |
| 104 | + |
72 | 105 | int64_t StreamAnalysis::GetOperationCost(mlir::Operation* op) const { |
73 | 106 | // Root has the lowest cost. |
74 | 107 | if (op == kRootOperation) return 1; |
75 | 108 |
|
76 | 109 | // A few TFRT kernels are guaranteed to be cheap. |
77 | | - if (llvm::isa<ReturnOp, MergeChainsOp>(op)) return 1; |
| 110 | + if (llvm::isa<mlir::func::ReturnOp, ReturnOp, MergeChainsOp>(op)) return 1; |
78 | 111 |
|
79 | | - // Check if operations defines a cost function. |
80 | | - if (auto cost_function = mlir::dyn_cast<CostFunctionInterface>(op)) { |
81 | | - int64_t cost = cost_function.cost(); |
82 | | - assert(cost > 0 && "cost must be a positive value"); |
83 | | - return cost; |
| 112 | + assert(cost_model_); |
| 113 | + if (auto cost = cost_model_->GetOperationCost(op)) { |
| 114 | + return *cost; |
84 | 115 | } |
85 | 116 |
|
86 | 117 | // If there is no cost specified for this operation, We conservatively return |
|
0 commit comments