Skip to content

Commit b635a76

Browse files
cky9301copybara-github
authored andcommitted
Generalize cost model interface in stream analysis
PiperOrigin-RevId: 507822487
1 parent 67d6e6c commit b635a76

File tree

2 files changed

+56
-8
lines changed

2 files changed

+56
-8
lines changed

include/tfrt/compiler/stream_analysis.h

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
#ifndef TFRT_COMPILER_STREAM_ANALYSIS_H_
5050
#define TFRT_COMPILER_STREAM_ANALYSIS_H_
5151

52+
#include <optional>
53+
5254
#include "llvm/ADT/SetVector.h"
5355
#include "mlir/Dialect/Func/IR/FuncOps.h"
5456
#include "mlir/IR/Block.h"
@@ -119,8 +121,21 @@ class Stream {
119121
// function.
120122
class StreamAnalysis {
121123
public:
122-
explicit StreamAnalysis(mlir::func::FuncOp op) : StreamAnalysis(op.front()) {}
123-
explicit StreamAnalysis(mlir::Block& block) { AnalyzeBlock(block); }
124+
class CostModelInterface {
125+
public:
126+
virtual ~CostModelInterface();
127+
128+
// The implementation is expected to return a positive value or
129+
// std::nullopt if a cost cannot be computed. If std::nullopt is returned,
130+
// stream analysis will use cost threshold as the cost for this op.
131+
virtual std::optional<int64_t> GetOperationCost(
132+
mlir::Operation* op) const = 0;
133+
};
134+
135+
explicit StreamAnalysis(mlir::func::FuncOp op,
136+
const CostModelInterface* cost_model = nullptr);
137+
explicit StreamAnalysis(mlir::Block& block,
138+
const CostModelInterface* cost_model = nullptr);
124139

125140
// Return the stream that contains `op`. An operation can only belong to one
126141
// stream.
@@ -228,6 +243,8 @@ class StreamAnalysis {
228243

229244
// `stream_map_` contains the finalized op-to-stream mapping.
230245
llvm::DenseMap<mlir::Operation*, Stream*> stream_map_;
246+
247+
const CostModelInterface* cost_model_ = nullptr;
231248
};
232249

233250
} // namespace compiler

lib/compiler/stream_analysis.cc

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
#include "tfrt/compiler/stream_analysis.h"
2020

21+
#include <optional>
22+
2123
#include "tfrt/basic_kernels/opdefs/basic_kernels.h"
2224
#include "tfrt/basic_kernels/opdefs/types.h"
2325
#include "tfrt/compiler/opdefs/tfrt_op_interfaces.h"
@@ -67,20 +69,49 @@ bool GetMergeInterDependentStreams(mlir::Block& block) {
6769
return false;
6870
}
6971

72+
class DefaultCostModel : public StreamAnalysis::CostModelInterface {
73+
public:
74+
std::optional<int64_t> GetOperationCost(mlir::Operation* op) const override {
75+
// Check if operations defines a cost function.
76+
if (auto cost_function = mlir::dyn_cast<CostFunctionInterface>(op)) {
77+
int64_t cost = cost_function.cost();
78+
assert(cost > 0 && "cost must be a positive value");
79+
return cost;
80+
}
81+
82+
return std::nullopt;
83+
}
84+
};
85+
86+
const DefaultCostModel* GetDefaultCostModel() {
87+
static const auto* const default_cost_model = new DefaultCostModel();
88+
return default_cost_model;
89+
}
90+
7091
} // namespace
7192

93+
StreamAnalysis::CostModelInterface::~CostModelInterface() = default;
94+
95+
StreamAnalysis::StreamAnalysis(mlir::func::FuncOp op,
96+
const CostModelInterface* cost_model)
97+
: StreamAnalysis(op.front(), cost_model) {}
98+
StreamAnalysis::StreamAnalysis(mlir::Block& block,
99+
const CostModelInterface* cost_model)
100+
: cost_model_(cost_model) {
101+
if (cost_model_ == nullptr) cost_model_ = GetDefaultCostModel();
102+
AnalyzeBlock(block);
103+
}
104+
72105
int64_t StreamAnalysis::GetOperationCost(mlir::Operation* op) const {
73106
// Root has the lowest cost.
74107
if (op == kRootOperation) return 1;
75108

76109
// A few TFRT kernels are guaranteed to be cheap.
77-
if (llvm::isa<ReturnOp, MergeChainsOp>(op)) return 1;
110+
if (llvm::isa<mlir::func::ReturnOp, ReturnOp, MergeChainsOp>(op)) return 1;
78111

79-
// Check if operations defines a cost function.
80-
if (auto cost_function = mlir::dyn_cast<CostFunctionInterface>(op)) {
81-
int64_t cost = cost_function.cost();
82-
assert(cost > 0 && "cost must be a positive value");
83-
return cost;
112+
assert(cost_model_);
113+
if (auto cost = cost_model_->GetOperationCost(op)) {
114+
return *cost;
84115
}
85116

86117
// If there is no cost specified for this operation, We conservatively return

0 commit comments

Comments
 (0)