diff --git a/paddle/cinn/hlir/framework/pir/compilation_task.cc b/paddle/cinn/hlir/framework/pir/compilation_task.cc index 2b8445b280f8a2..67fe83c0950340 100644 --- a/paddle/cinn/hlir/framework/pir/compilation_task.cc +++ b/paddle/cinn/hlir/framework/pir/compilation_task.cc @@ -337,6 +337,7 @@ std::shared_ptr CompilationTask::BuildPirCINNKernelInfo( context_->group_->FuncName() + "_infer_shape", context_->group_->symbol_args_map(), context_->group_->temp_space_sizes()); + VLOG(5) << "Start to compile module into cuda kernel..."; backend_resource->GetBackendCompiler()->SetFusionHash( context_->GetFusionHash()); backend_resource->GetBackendCompiler()->Build(module, diff --git a/paddle/cinn/hlir/framework/pir/fusion_info.cc b/paddle/cinn/hlir/framework/pir/fusion_info.cc index 10f5c33c46b2ba..0514e2735397f3 100644 --- a/paddle/cinn/hlir/framework/pir/fusion_info.cc +++ b/paddle/cinn/hlir/framework/pir/fusion_info.cc @@ -20,6 +20,7 @@ #include "paddle/pir/include/core/ir_printer.h" #include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" PD_DECLARE_bool(enable_cinn_compile_cache); +COMMON_DECLARE_bool(enable_cinn_kernel_cache); namespace cinn::hlir::framework::pir { constexpr static const char* kOpCallStack = "op_callstack"; @@ -30,6 +31,7 @@ static const std::unordered_set kExcludedAttrs = { kOpCallStack, kSymShapeStr, kStructName, kStopGradient}; std::size_t AttributeInfo::hash() const { + if (!FLAGS_enable_cinn_kernel_cache) return attr_.hash(); // Use stable attribute information to calculate hash instead of pointer // addresses std::size_t seed = 0; @@ -57,6 +59,7 @@ std::ostream& operator<<(std::ostream& os, const AttributeInfo& attr_info) { } std::size_t ValueInfo::hash() const { + if (!FLAGS_enable_cinn_kernel_cache) return type_.hash(); // Use stable type information to calculate hash std::size_t seed = 0; @@ -87,14 +90,17 @@ OperationInfo::OperationInfo(const ::pir::Operation& op) { input_infos_.emplace_back(value); } output_infos_.reserve(op.num_results()); - output_infos_symbol_.reserve(op.num_results()); + if (FLAGS_enable_cinn_kernel_cache) + output_infos_symbol_.reserve(op.num_results()); for (const auto value : op.results()) { if (!value || !value.type()) continue; output_infos_.emplace_back(value); - auto& shape_analysis = ::pir::ShapeAnalysisManager::Instance().Get( - const_cast<::pir::Operation&>(op).GetParentProgram()); - output_infos_symbol_.push_back( - shape_analysis.GetShapeOrDataForValue(value)); + if (FLAGS_enable_cinn_kernel_cache) { + auto& shape_analysis = ::pir::ShapeAnalysisManager::Instance().Get( + const_cast<::pir::Operation&>(op).GetParentProgram()); + output_infos_symbol_.push_back( + shape_analysis.GetShapeOrDataForValue(value)); + } } // Keep attributes always in order. const auto& attributes = op.attributes(); @@ -114,8 +120,10 @@ std::size_t OperationInfo::hash() const { hash_combine(seed, info); } for (const auto& info : output_infos_) hash_combine(seed, info); - for (const auto& shape_or_data : output_infos_symbol_) - hash_combine(seed, shape_or_data); + if (FLAGS_enable_cinn_kernel_cache) { + for (const auto& shape_or_data : output_infos_symbol_) + hash_combine(seed, shape_or_data); + } for (const auto& info : attr_infos_) hash_combine(seed, info); return seed; } @@ -254,6 +262,9 @@ std::size_t FusionInfo::hash() const { std::size_t seed = 2153; for (const auto& info : op_infos_) hash_combine(seed, info); for (const auto& dim_expr : input_dim_exprs_) hash_combine(seed, dim_expr); + if (!FLAGS_enable_cinn_kernel_cache) { + hash_combine(seed, *program_info_); + } if (!FLAGS_enable_cinn_compile_cache) hash_combine(seed, unique_fn_name_); return seed; } diff --git a/paddle/cinn/hlir/framework/pir_compiler.cc b/paddle/cinn/hlir/framework/pir_compiler.cc index 288f52edd898c6..0c2a446db9063d 100644 --- a/paddle/cinn/hlir/framework/pir_compiler.cc +++ b/paddle/cinn/hlir/framework/pir_compiler.cc @@ -240,7 +240,6 @@ std::vector PirCompiler::Build( std::string cache_dir = FLAGS_cinn_kernel_cache_save_path + "/" + std::to_string(device_id.value()) + "/" + source_hash; - llvm::sys::fs::create_directories(cache_dir); std::string cache_so_path = cache_dir + "/" + CINN_CACHE_SO; std::string meta_filepath = cache_dir + "/" + CINN_CACHE_META; // Check if .so exists @@ -285,6 +284,9 @@ std::vector PirCompiler::Build( compilation_results[index] = result; } else { + if (FLAGS_enable_cinn_kernel_cache) { + llvm::sys::fs::create_directories(cache_dir); + } // Compilation path compilation_results[index] = Compile(&group_compilation_contexts[index]); diff --git a/paddle/pir/include/core/type.h b/paddle/pir/include/core/type.h index 6e8dd1a27ab9cf..fcfcdeb635b41d 100644 --- a/paddle/pir/include/core/type.h +++ b/paddle/pir/include/core/type.h @@ -125,15 +125,7 @@ class IR_API Type { /// bool IsIntOrIndex() const; bool IsIndex() const; - - std::size_t hash() const { - if (!storage_) return 0; - std::ostringstream oss; - Print(oss); - std::string type_representation = oss.str(); - std::size_t seed = std::hash{}(type_representation); - return seed; - } + std::size_t hash() const { return std::hash()(storage_); } protected: const Storage *storage_{nullptr};