diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp index aafdaf4eb137..71752bb19c25 100644 --- a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp +++ b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp @@ -23,10 +23,12 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" @@ -104,9 +106,124 @@ class CIRCallOpLowering : public mlir::OpConversionPattern { if (mlir::failed( getTypeConverter()->convertTypes(op.getResultTypes(), types))) return mlir::failure(); - rewriter.replaceOpWithNewOp( - op, op.getCalleeAttr(), types, adaptor.getOperands()); - return mlir::LogicalResult::success(); + + if (!op.isIndirect()) { + // Currently variadic functions are not supported by the builtin func + // dialect. For now only basic call to printf are supported by using the + // llvmir dialect. + // TODO: remove this and add support for variadic function calls once + // TODO: supported by the func dialect + if (op.getCallee()->equals_insensitive("printf")) { + SmallVector operandTypes = + llvm::to_vector(adaptor.getOperands().getTypes()); + + // Drop the initial memref operand type (we replace the memref format + // string with equivalent llvm.mlir ops) + operandTypes.erase(operandTypes.begin()); + + // Check that the printf attributes can be used in llvmir dialect (i.e + // they have integer/float type) + if (!llvm::all_of(operandTypes, [](mlir::Type ty) { + return mlir::LLVM::isCompatibleType(ty); + })) { + return op.emitError() + << "lowering of printf attributes having a type that is " + "converted to memref in cir-to-mlir lowering (e.g. " + "pointers) not supported yet"; + } + + // Currently only versions of printf are supported where the format + // string is defined inside the printf ==> the lowering of the cir ops + // will match: + // %global = memref.get_global %frm_str + // %* = memref.reinterpret_cast (%global, 0) + if (auto reinterpret_castOP = + mlir::dyn_cast_or_null( + adaptor.getOperands()[0].getDefiningOp())) { + if (auto getGlobalOp = + mlir::dyn_cast_or_null( + reinterpret_castOP->getOperand(0).getDefiningOp())) { + mlir::ModuleOp parentModule = op->getParentOfType(); + + auto context = rewriter.getContext(); + + // Find the memref.global op defining the frm_str + auto globalOp = parentModule.lookupSymbol( + getGlobalOp.getNameAttr()); + + rewriter.setInsertionPoint(globalOp); + + // Insert a equivalent llvm.mlir.global + auto initialvalueAttr = + mlir::dyn_cast_or_null( + globalOp.getInitialValueAttr()); + + auto type = mlir::LLVM::LLVMArrayType::get( + mlir::IntegerType::get(context, 8), + initialvalueAttr.getNumElements()); + + auto llvmglobalOp = rewriter.create( + globalOp->getLoc(), type, true, mlir::LLVM::Linkage::Internal, + "printf_format_" + globalOp.getSymName().str(), + initialvalueAttr, 0); + + rewriter.setInsertionPoint(getGlobalOp); + + // Insert llvmir dialect ops to retrive the !llvm.ptr of the global + auto globalPtrOp = rewriter.create( + getGlobalOp->getLoc(), llvmglobalOp); + + mlir::Value cst0 = rewriter.create( + getGlobalOp->getLoc(), rewriter.getI8Type(), + rewriter.getIndexAttr(0)); + auto gepPtrOp = rewriter.create( + getGlobalOp->getLoc(), + mlir::LLVM::LLVMPointerType::get(context), + llvmglobalOp.getType(), globalPtrOp, + ArrayRef({cst0, cst0})); + + mlir::ValueRange operands = adaptor.getOperands(); + + // Replace the old memref operand with the !llvm.ptr for the frm_str + mlir::SmallVector newOperands; + newOperands.push_back(gepPtrOp); + newOperands.append(operands.begin() + 1, operands.end()); + + // Create the llvmir dialect function type for printf + auto llvmI32Ty = mlir::IntegerType::get(context, 32); + auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(context); + auto llvmFnType = + mlir::LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy, + /*isVarArg=*/true); + + rewriter.setInsertionPoint(op); + + // Insert an llvm.call op with the updated operands to printf + rewriter.replaceOpWithNewOp( + op, llvmFnType, op.getCalleeAttr(), newOperands); + + // Cleanup printf frm_str memref ops + rewriter.eraseOp(reinterpret_castOP); + rewriter.eraseOp(getGlobalOp); + rewriter.eraseOp(globalOp); + + return mlir::LogicalResult::success(); + } + } + + return op.emitError() + << "lowering of printf function with Format-String" + "defined outside of printf is not supported yet"; + } + + rewriter.replaceOpWithNewOp( + op, op.getCalleeAttr(), types, adaptor.getOperands()); + return mlir::LogicalResult::success(); + + } else { + // TODO: support lowering of indirect calls via func.call_indirect op + return op.emitError() << "lowering of indirect calls not supported yet"; + } } }; @@ -557,37 +674,60 @@ class CIRFuncOpLowering : public mlir::OpConversionPattern { mlir::ConversionPatternRewriter &rewriter) const override { auto fnType = op.getFunctionType(); - mlir::TypeConverter::SignatureConversion signatureConversion( - fnType.getNumInputs()); - for (const auto &argType : enumerate(fnType.getInputs())) { - auto convertedType = typeConverter->convertType(argType.value()); - if (!convertedType) - return mlir::failure(); - signatureConversion.addInputs(argType.index(), convertedType); - } + if (fnType.isVarArg()) { + // TODO: once the func dialect supports variadic functions rewrite this + // For now only insert special handling of printf via the llvmir dialect + if (op.getSymName().equals_insensitive("printf")) { + auto context = rewriter.getContext(); + // Create a llvmir dialect function declaration for printf, the + // signature is: i32 (!llvm.ptr, ...) + auto llvmI32Ty = mlir::IntegerType::get(context, 32); + auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(context); + auto llvmFnType = + mlir::LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy, + /*isVarArg=*/true); + auto printfFunc = rewriter.create( + op.getLoc(), "printf", llvmFnType); + rewriter.replaceOp(op, printfFunc); + } else { + rewriter.eraseOp(op); + return op.emitError() << "lowering of variadic functions (except " + "printf) not supported yet"; + } + } else { + mlir::TypeConverter::SignatureConversion signatureConversion( + fnType.getNumInputs()); + + for (const auto &argType : enumerate(fnType.getInputs())) { + auto convertedType = typeConverter->convertType(argType.value()); + if (!convertedType) + return mlir::failure(); + signatureConversion.addInputs(argType.index(), convertedType); + } - SmallVector passThroughAttrs; + SmallVector passThroughAttrs; - if (auto symVisibilityAttr = op.getSymVisibilityAttr()) - passThroughAttrs.push_back( - rewriter.getNamedAttr("sym_visibility", symVisibilityAttr)); + if (auto symVisibilityAttr = op.getSymVisibilityAttr()) + passThroughAttrs.push_back( + rewriter.getNamedAttr("sym_visibility", symVisibilityAttr)); - mlir::Type resultType = - getTypeConverter()->convertType(fnType.getReturnType()); - auto fn = rewriter.create( - op.getLoc(), op.getName(), - rewriter.getFunctionType(signatureConversion.getConvertedTypes(), - resultType ? mlir::TypeRange(resultType) - : mlir::TypeRange()), - passThroughAttrs); + mlir::Type resultType = + getTypeConverter()->convertType(fnType.getReturnType()); + auto fn = rewriter.create( + op.getLoc(), op.getName(), + rewriter.getFunctionType(signatureConversion.getConvertedTypes(), + resultType ? mlir::TypeRange(resultType) + : mlir::TypeRange()), + passThroughAttrs); - if (failed(rewriter.convertRegionTypes(&op.getBody(), *typeConverter, - &signatureConversion))) - return mlir::failure(); - rewriter.inlineRegionBefore(op.getBody(), fn.getBody(), fn.end()); + if (failed(rewriter.convertRegionTypes(&op.getBody(), *typeConverter, + &signatureConversion))) + return mlir::failure(); + rewriter.inlineRegionBefore(op.getBody(), fn.getBody(), fn.end()); - rewriter.eraseOp(op); + rewriter.eraseOp(op); + } return mlir::LogicalResult::success(); } }; diff --git a/clang/test/CIR/Lowering/ThroughMLIR/call.c b/clang/test/CIR/Lowering/ThroughMLIR/call.c index 3edb7bc83cdc..d130c10435b5 100644 --- a/clang/test/CIR/Lowering/ThroughMLIR/call.c +++ b/clang/test/CIR/Lowering/ThroughMLIR/call.c @@ -12,3 +12,41 @@ int test(void) { // CHECK: %[[ARG:.+]] = arith.constant 2 : i32 // CHECK-NEXT: call @foo(%[[ARG]]) : (i32) -> () // CHECK: } + +extern int printf(const char *str, ...); + +// CHECK-LABEL: llvm.func @printf(!llvm.ptr, ...) -> i32 +// CHECK: llvm.mlir.global internal constant @[[FRMT_STR:.*]](dense<[37, 100, 44, 32, 37, 102, 44, 32, 37, 100, 44, 32, 37, 108, 108, 100, 44, 32, 37, 100, 44, 32, 37, 102, 10, 0]> : tensor<26xi8>) {addr_space = 0 : i32} : !llvm.array<26 x i8> + +void testfunc(short s, float X, char C, long long LL, int I, double D) { + printf("%d, %f, %d, %lld, %d, %f\n", s, X, C, LL, I, D); +} + +// CHECK: func.func @testfunc(%[[ARG0:.*]]: i16 {{.*}}, %[[ARG1:.*]]: f32 {{.*}}, %[[ARG2:.*]]: i8 {{.*}}, %[[ARG3:.*]]: i64 {{.*}}, %[[ARG4:.*]]: i32 {{.*}}, %[[ARG5:.*]]: f64 {{.*}}) { +// CHECK: %[[ALLOCA_S:.*]] = memref.alloca() {alignment = 2 : i64} : memref +// CHECK: %[[ALLOCA_X:.*]] = memref.alloca() {alignment = 4 : i64} : memref +// CHECK: %[[ALLOCA_C:.*]] = memref.alloca() {alignment = 1 : i64} : memref +// CHECK: %[[ALLOCA_LL:.*]] = memref.alloca() {alignment = 8 : i64} : memref +// CHECK: %[[ALLOCA_I:.*]] = memref.alloca() {alignment = 4 : i64} : memref +// CHECK: %[[ALLOCA_D:.*]] = memref.alloca() {alignment = 8 : i64} : memref +// CHECK: memref.store %[[ARG0]], %[[ALLOCA_S]][] : memref +// CHECK: memref.store %[[ARG1]], %[[ALLOCA_X]][] : memref +// CHECK: memref.store %[[ARG2]], %[[ALLOCA_C]][] : memref +// CHECK: memref.store %[[ARG3]], %[[ALLOCA_LL]][] : memref +// CHECK: memref.store %[[ARG4]], %[[ALLOCA_I]][] : memref +// CHECK: memref.store %[[ARG5]], %[[ALLOCA_D]][] : memref +// CHECK: %[[FRMT_STR_ADDR:.*]] = llvm.mlir.addressof @[[FRMT_STR]] : !llvm.ptr +// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i8 +// CHECK: %[[FRMT_STR_DATA:.*]] = llvm.getelementptr %[[FRMT_STR_ADDR]][%[[C0]], %[[C0]]] : (!llvm.ptr, i8, i8) -> !llvm.ptr, !llvm.array<26 x i8> +// CHECK: %[[S:.*]] = memref.load %[[ALLOCA_S]][] : memref +// CHECK: %[[S_EXT:.*]] = arith.extsi %3 : i16 to i32 +// CHECK: %[[X:.*]] = memref.load %[[ALLOCA_X]][] : memref +// CHECK: %[[X_EXT:.*]] = arith.extf %5 : f32 to f64 +// CHECK: %[[C:.*]] = memref.load %[[ALLOCA_C]][] : memref +// CHECK: %[[C_EXT:.*]] = arith.extsi %7 : i8 to i32 +// CHECK: %[[LL:.*]] = memref.load %[[ALLOCA_LL]][] : memref +// CHECK: %[[I:.*]] = memref.load %[[ALLOCA_I]][] : memref +// CHECK: %[[D:.*]] = memref.load %[[ALLOCA_D]][] : memref +// CHECK: {{.*}} = llvm.call @printf(%[[FRMT_STR_DATA]], %[[S_EXT]], %[[X_EXT]], %[[C_EXT]], %[[LL]], %[[I]], %[[D]]) vararg(!llvm.func) : (!llvm.ptr, i32, f64, i32, i64, i32, f64) -> i32 +// CHECK: return +// CHECK: }