Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Fix hash() when FLAGS_enable_cinn_kernel_cache is false.
  • Loading branch information
YuhanXu committed Dec 15, 2025
commit faf6d5ca127729a6a4d8506f216165f8220e8c65
1 change: 1 addition & 0 deletions paddle/cinn/hlir/framework/pir/compilation_task.cc
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ std::shared_ptr<pir::CompilationResult> CompilationTask::BuildPirCINNKernelInfo(
context_->group_->FuncName() + "_infer_shape",
context_->group_->symbol_args_map(),
context_->group_->temp_space_sizes());
VLOG(5) << "Start to compile module into cuda kernel...";
backend_resource->GetBackendCompiler()->SetFusionHash(
context_->GetFusionHash());
backend_resource->GetBackendCompiler()->Build(module,
Expand Down
25 changes: 18 additions & 7 deletions paddle/cinn/hlir/framework/pir/fusion_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "paddle/pir/include/core/ir_printer.h"
#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h"
PD_DECLARE_bool(enable_cinn_compile_cache);
COMMON_DECLARE_bool(enable_cinn_kernel_cache);
namespace cinn::hlir::framework::pir {

constexpr static const char* kOpCallStack = "op_callstack";
Expand All @@ -30,6 +31,7 @@ static const std::unordered_set<std::string> kExcludedAttrs = {
kOpCallStack, kSymShapeStr, kStructName, kStopGradient};

std::size_t AttributeInfo::hash() const {
if (!FLAGS_enable_cinn_kernel_cache) return attr_.hash();
// Use stable attribute information to calculate hash instead of pointer
// addresses
std::size_t seed = 0;
Expand Down Expand Up @@ -57,6 +59,7 @@ std::ostream& operator<<(std::ostream& os, const AttributeInfo& attr_info) {
}

std::size_t ValueInfo::hash() const {
if (!FLAGS_enable_cinn_kernel_cache) return type_.hash();
// Use stable type information to calculate hash
std::size_t seed = 0;

Expand Down Expand Up @@ -87,14 +90,17 @@ OperationInfo::OperationInfo(const ::pir::Operation& op) {
input_infos_.emplace_back(value);
}
output_infos_.reserve(op.num_results());
output_infos_symbol_.reserve(op.num_results());
if (FLAGS_enable_cinn_kernel_cache)
output_infos_symbol_.reserve(op.num_results());
for (const auto value : op.results()) {
if (!value || !value.type()) continue;
output_infos_.emplace_back(value);
auto& shape_analysis = ::pir::ShapeAnalysisManager::Instance().Get(
const_cast<::pir::Operation&>(op).GetParentProgram());
output_infos_symbol_.push_back(
shape_analysis.GetShapeOrDataForValue(value));
if (FLAGS_enable_cinn_kernel_cache) {
auto& shape_analysis = ::pir::ShapeAnalysisManager::Instance().Get(
const_cast<::pir::Operation&>(op).GetParentProgram());
output_infos_symbol_.push_back(
shape_analysis.GetShapeOrDataForValue(value));
}
}
// Keep attributes always in order.
const auto& attributes = op.attributes();
Expand All @@ -114,8 +120,10 @@ std::size_t OperationInfo::hash() const {
hash_combine(seed, info);
}
for (const auto& info : output_infos_) hash_combine(seed, info);
for (const auto& shape_or_data : output_infos_symbol_)
hash_combine(seed, shape_or_data);
if (FLAGS_enable_cinn_kernel_cache) {
for (const auto& shape_or_data : output_infos_symbol_)
hash_combine(seed, shape_or_data);
}
for (const auto& info : attr_infos_) hash_combine(seed, info);
return seed;
}
Expand Down Expand Up @@ -254,6 +262,9 @@ std::size_t FusionInfo::hash() const {
std::size_t seed = 2153;
for (const auto& info : op_infos_) hash_combine(seed, info);
for (const auto& dim_expr : input_dim_exprs_) hash_combine(seed, dim_expr);
if (!FLAGS_enable_cinn_kernel_cache) {
hash_combine(seed, *program_info_);
}
if (!FLAGS_enable_cinn_compile_cache) hash_combine(seed, unique_fn_name_);
return seed;
}
Expand Down
4 changes: 3 additions & 1 deletion paddle/cinn/hlir/framework/pir_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,6 @@ std::vector<pir::CINNKernelInfo> PirCompiler::Build(
std::string cache_dir = FLAGS_cinn_kernel_cache_save_path + "/" +
std::to_string(device_id.value()) + "/" +
source_hash;
llvm::sys::fs::create_directories(cache_dir);
std::string cache_so_path = cache_dir + "/" + CINN_CACHE_SO;
std::string meta_filepath = cache_dir + "/" + CINN_CACHE_META;
// Check if .so exists
Expand Down Expand Up @@ -285,6 +284,9 @@ std::vector<pir::CINNKernelInfo> PirCompiler::Build(
compilation_results[index] = result;

} else {
if (FLAGS_enable_cinn_kernel_cache) {
llvm::sys::fs::create_directories(cache_dir);
}
// Compilation path
compilation_results[index] =
Compile(&group_compilation_contexts[index]);
Expand Down
10 changes: 1 addition & 9 deletions paddle/pir/include/core/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,7 @@ class IR_API Type {
///
bool IsIntOrIndex() const;
bool IsIndex() const;

std::size_t hash() const {
if (!storage_) return 0;
std::ostringstream oss;
Print(oss);
std::string type_representation = oss.str();
std::size_t seed = std::hash<std::string>{}(type_representation);
return seed;
}
std::size_t hash() const { return std::hash<const void *>()(storage_); }

protected:
const Storage *storage_{nullptr};
Expand Down
Loading