Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/cinn/hlir/framework/pir/compilation_task.cc
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ std::shared_ptr<pir::CompilationResult> CompilationTask::BuildPirCINNKernelInfo(
context_->group_->FuncName() + "_infer_shape",
context_->group_->symbol_args_map(),
context_->group_->temp_space_sizes());
VLOG(5) << "Start to compile module into cuda kernel...";
backend_resource->GetBackendCompiler()->SetFusionHash(
context_->GetFusionHash());
backend_resource->GetBackendCompiler()->Build(module,
Expand Down
25 changes: 18 additions & 7 deletions paddle/cinn/hlir/framework/pir/fusion_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "paddle/pir/include/core/ir_printer.h"
#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h"
PD_DECLARE_bool(enable_cinn_compile_cache);
COMMON_DECLARE_bool(enable_cinn_kernel_cache);
namespace cinn::hlir::framework::pir {

constexpr static const char* kOpCallStack = "op_callstack";
Expand All @@ -30,6 +31,7 @@ static const std::unordered_set<std::string> kExcludedAttrs = {
kOpCallStack, kSymShapeStr, kStructName, kStopGradient};

std::size_t AttributeInfo::hash() const {
if (!FLAGS_enable_cinn_kernel_cache) return attr_.hash();
// Use stable attribute information to calculate hash instead of pointer
// addresses
std::size_t seed = 0;
Expand Down Expand Up @@ -57,6 +59,7 @@ std::ostream& operator<<(std::ostream& os, const AttributeInfo& attr_info) {
}

std::size_t ValueInfo::hash() const {
if (!FLAGS_enable_cinn_kernel_cache) return type_.hash();
// Use stable type information to calculate hash
std::size_t seed = 0;

Expand Down Expand Up @@ -87,14 +90,17 @@ OperationInfo::OperationInfo(const ::pir::Operation& op) {
input_infos_.emplace_back(value);
}
output_infos_.reserve(op.num_results());
output_infos_symbol_.reserve(op.num_results());
if (FLAGS_enable_cinn_kernel_cache)
output_infos_symbol_.reserve(op.num_results());
for (const auto value : op.results()) {
if (!value || !value.type()) continue;
output_infos_.emplace_back(value);
auto& shape_analysis = ::pir::ShapeAnalysisManager::Instance().Get(
const_cast<::pir::Operation&>(op).GetParentProgram());
output_infos_symbol_.push_back(
shape_analysis.GetShapeOrDataForValue(value));
if (FLAGS_enable_cinn_kernel_cache) {
auto& shape_analysis = ::pir::ShapeAnalysisManager::Instance().Get(
const_cast<::pir::Operation&>(op).GetParentProgram());
output_infos_symbol_.push_back(
shape_analysis.GetShapeOrDataForValue(value));
}
}
// Keep attributes always in order.
const auto& attributes = op.attributes();
Expand All @@ -114,8 +120,10 @@ std::size_t OperationInfo::hash() const {
hash_combine(seed, info);
}
for (const auto& info : output_infos_) hash_combine(seed, info);
for (const auto& shape_or_data : output_infos_symbol_)
hash_combine(seed, shape_or_data);
if (FLAGS_enable_cinn_kernel_cache) {
for (const auto& shape_or_data : output_infos_symbol_)
hash_combine(seed, shape_or_data);
}
for (const auto& info : attr_infos_) hash_combine(seed, info);
return seed;
}
Expand Down Expand Up @@ -254,6 +262,9 @@ std::size_t FusionInfo::hash() const {
std::size_t seed = 2153;
for (const auto& info : op_infos_) hash_combine(seed, info);
for (const auto& dim_expr : input_dim_exprs_) hash_combine(seed, dim_expr);
if (!FLAGS_enable_cinn_kernel_cache) {
hash_combine(seed, *program_info_);
}
if (!FLAGS_enable_cinn_compile_cache) hash_combine(seed, unique_fn_name_);
return seed;
}
Expand Down
4 changes: 3 additions & 1 deletion paddle/cinn/hlir/framework/pir_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,6 @@ std::vector<pir::CINNKernelInfo> PirCompiler::Build(
std::string cache_dir = FLAGS_cinn_kernel_cache_save_path + "/" +
std::to_string(device_id.value()) + "/" +
source_hash;
llvm::sys::fs::create_directories(cache_dir);
std::string cache_so_path = cache_dir + "/" + CINN_CACHE_SO;
std::string meta_filepath = cache_dir + "/" + CINN_CACHE_META;
// Check if .so exists
Expand Down Expand Up @@ -285,6 +284,9 @@ std::vector<pir::CINNKernelInfo> PirCompiler::Build(
compilation_results[index] = result;

} else {
if (FLAGS_enable_cinn_kernel_cache) {
llvm::sys::fs::create_directories(cache_dir);
}
// Compilation path
compilation_results[index] =
Compile(&group_compilation_contexts[index]);
Expand Down
10 changes: 1 addition & 9 deletions paddle/pir/include/core/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,7 @@ class IR_API Type {
///
bool IsIntOrIndex() const;
bool IsIndex() const;

std::size_t hash() const {
if (!storage_) return 0;
std::ostringstream oss;
Print(oss);
std::string type_representation = oss.str();
std::size_t seed = std::hash<std::string>{}(type_representation);
return seed;
}
std::size_t hash() const { return std::hash<const void *>()(storage_); }

protected:
const Storage *storage_{nullptr};
Expand Down
Loading