1616#include " mlir/Dialect/LLVMIR/LLVMTypes.h"
1717#include " mlir/IR/Builders.h"
1818#include " mlir/IR/DialectImplementation.h"
19+ #include " mlir/IR/FunctionImplementation.h"
1920#include " mlir/IR/OpDefinition.h"
2021#include " mlir/IR/OpImplementation.h"
2122#include " mlir/IR/TypeUtilities.h"
@@ -58,6 +59,54 @@ void cir::CIRDialect::initialize() {
5859 addInterfaces<CIROpAsmDialectInterface>();
5960}
6061
62+ // ===----------------------------------------------------------------------===//
63+ // Helpers
64+ // ===----------------------------------------------------------------------===//
65+
66+ // Parses one of the keywords provided in the list `keywords` and returns the
67+ // position of the parsed keyword in the list. If none of the keywords from the
68+ // list is parsed, returns -1.
69+ static int parseOptionalKeywordAlternative (OpAsmParser &parser,
70+ ArrayRef<StringRef> keywords) {
71+ for (auto en : llvm::enumerate (keywords)) {
72+ if (succeeded (parser.parseOptionalKeyword (en.value ())))
73+ return en.index ();
74+ }
75+ return -1 ;
76+ }
77+
78+ namespace {
79+ template <typename Ty>
80+ struct EnumTraits {};
81+
82+ #define REGISTER_ENUM_TYPE (Ty ) \
83+ template <> \
84+ struct EnumTraits <Ty> { \
85+ static StringRef stringify (Ty value) { return stringify##Ty (value); } \
86+ static unsigned getMaxEnumVal () { return getMaxEnumValFor##Ty (); } \
87+ }
88+
89+ REGISTER_ENUM_TYPE (GlobalLinkageKind);
90+ } // namespace
91+
92+ // / Parse an enum from the keyword, or default to the provided default value.
93+ // / The return type is the enum type by default, unless overriden with the
94+ // / second template argument.
95+ // / TODO: teach other places in this file to use this function.
96+ template <typename EnumTy, typename RetTy = EnumTy>
97+ static RetTy parseOptionalCIRKeyword (OpAsmParser &parser,
98+ OperationState &result,
99+ EnumTy defaultValue) {
100+ SmallVector<StringRef, 10 > names;
101+ for (unsigned i = 0 , e = EnumTraits<EnumTy>::getMaxEnumVal (); i <= e; ++i)
102+ names.push_back (EnumTraits<EnumTy>::stringify (static_cast <EnumTy>(i)));
103+
104+ int index = parseOptionalKeywordAlternative (parser, names);
105+ if (index == -1 )
106+ return static_cast <RetTy>(defaultValue);
107+ return static_cast <RetTy>(index);
108+ }
109+
61110// ===----------------------------------------------------------------------===//
62111// ConstantOp
63112// ===----------------------------------------------------------------------===//
@@ -190,7 +239,7 @@ static LogicalResult verify(cir::CastOp castOp) {
190239// ===----------------------------------------------------------------------===//
191240
192241static mlir::LogicalResult checkReturnAndFunction (ReturnOp op,
193- FuncOp function) {
242+ mlir:: FuncOp function) {
194243 // ReturnOps currently only have a single optional operand.
195244 if (op.getNumOperands () > 1 )
196245 return op.emitOpError () << " expects at most 1 return operand" ;
@@ -223,11 +272,11 @@ static mlir::LogicalResult verify(ReturnOp op) {
223272 // Returns can be present in multiple different scopes, get the
224273 // wrapping function and start from there.
225274 auto *fnOp = op->getParentOp ();
226- while (!isa<FuncOp>(fnOp))
275+ while (!isa<mlir:: FuncOp>(fnOp))
227276 fnOp = fnOp->getParentOp ();
228277
229278 // Make sure return types match function return type.
230- if (checkReturnAndFunction (op, cast<FuncOp>(fnOp)).failed ())
279+ if (checkReturnAndFunction (op, cast<mlir:: FuncOp>(fnOp)).failed ())
231280 return failure ();
232281
233282 return success ();
@@ -495,7 +544,7 @@ static LogicalResult verify(ScopeOp op) {
495544
496545static mlir::LogicalResult verify (YieldOp op) {
497546 auto isDominatedByLoopOrSwitch = [](Operation *parentOp) {
498- while (!llvm::isa<FuncOp>(parentOp)) {
547+ while (!llvm::isa<mlir:: FuncOp>(parentOp)) {
499548 if (llvm::isa<cir::SwitchOp, cir::LoopOp>(parentOp))
500549 return true ;
501550 parentOp = parentOp->getParentOp ();
@@ -504,7 +553,7 @@ static mlir::LogicalResult verify(YieldOp op) {
504553 };
505554
506555 auto isDominatedByLoop = [](Operation *parentOp) {
507- while (!llvm::isa<FuncOp>(parentOp)) {
556+ while (!llvm::isa<mlir:: FuncOp>(parentOp)) {
508557 if (llvm::isa<cir::LoopOp>(parentOp))
509558 return true ;
510559 parentOp = parentOp->getParentOp ();
@@ -1013,15 +1062,15 @@ static LogicalResult verify(GlobalOp op) {
10131062 }
10141063
10151064 switch (op.linkage ()) {
1016- case mlir::cir:: GlobalLinkageKind::InternalLinkage:
1017- case mlir::cir:: GlobalLinkageKind::PrivateLinkage:
1065+ case GlobalLinkageKind::InternalLinkage:
1066+ case GlobalLinkageKind::PrivateLinkage:
10181067 if (op.isPublic ())
10191068 return op->emitError ()
10201069 << " public visibility not allowed with '"
10211070 << stringifyGlobalLinkageKind (op.linkage ()) << " ' linkage" ;
10221071 break ;
1023- case mlir::cir:: GlobalLinkageKind::ExternalLinkage:
1024- case mlir::cir:: GlobalLinkageKind::ExternalWeakLinkage:
1072+ case GlobalLinkageKind::ExternalLinkage:
1073+ case GlobalLinkageKind::ExternalWeakLinkage:
10251074 if (op.isPrivate ())
10261075 return op->emitError ()
10271076 << " private visibility not allowed with '"
@@ -1072,6 +1121,144 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
10721121 return success ();
10731122}
10741123
1124+ // ===----------------------------------------------------------------------===//
1125+ // FuncOp
1126+ // ===----------------------------------------------------------------------===//
1127+
1128+ // / Returns the name used for the linkage attribute. This *must* correspond to
1129+ // / the name of the attribute in ODS.
1130+ static StringRef getLinkageAttrName () { return " linkage" ; }
1131+
1132+ void cir::FuncOp::build (OpBuilder &builder, OperationState &result,
1133+ StringRef name, FunctionType type,
1134+ GlobalLinkageKind linkage,
1135+ ArrayRef<NamedAttribute> attrs,
1136+ ArrayRef<DictionaryAttr> argAttrs) {
1137+ result.addRegion ();
1138+ result.addAttribute (SymbolTable::getSymbolAttrName (),
1139+ builder.getStringAttr (name));
1140+ result.addAttribute (getTypeAttrName (), TypeAttr::get (type));
1141+ result.addAttribute (getLinkageAttrName (), GlobalLinkageKindAttr::get (
1142+ builder.getContext (), linkage));
1143+ result.attributes .append (attrs.begin (), attrs.end ());
1144+ if (argAttrs.empty ())
1145+ return ;
1146+
1147+ function_interface_impl::addArgAndResultAttrs (builder, result, argAttrs,
1148+ /* resultAttrs=*/ llvm::None);
1149+ }
1150+
1151+ static ParseResult parseCIRFuncOp (OpAsmParser &parser, OperationState &state) {
1152+ // Default to external linkage if no keyword is provided.
1153+ state.addAttribute (
1154+ getLinkageAttrName (),
1155+ GlobalLinkageKindAttr::get (
1156+ parser.getContext (),
1157+ parseOptionalCIRKeyword<GlobalLinkageKind>(
1158+ parser, state, GlobalLinkageKind::ExternalLinkage)));
1159+
1160+ StringAttr nameAttr;
1161+ SmallVector<OpAsmParser::OperandType, 8 > entryArgs;
1162+ SmallVector<NamedAttrList, 1 > argAttrs;
1163+ SmallVector<Optional<Location>, 1 > argLocations;
1164+ SmallVector<NamedAttrList, 1 > resultAttrs;
1165+ SmallVector<Type, 8 > argTypes;
1166+ SmallVector<Type, 4 > resultTypes;
1167+ auto &builder = parser.getBuilder ();
1168+
1169+ // Parse the name as a symbol.
1170+ if (parser.parseSymbolName (nameAttr, SymbolTable::getSymbolAttrName (),
1171+ state.attributes ))
1172+ return failure ();
1173+
1174+ // Parse the function signature.
1175+ bool isVariadic = false ;
1176+ if (function_interface_impl::parseFunctionSignature (
1177+ parser, /* allowVariadic=*/ false , entryArgs, argTypes, argAttrs,
1178+ argLocations, isVariadic, resultTypes, resultAttrs))
1179+ return failure ();
1180+
1181+ auto fnType = builder.getFunctionType (argTypes, resultTypes);
1182+ state.addAttribute (function_interface_impl::getTypeAttrName (),
1183+ TypeAttr::get (fnType));
1184+
1185+ // If additional attributes are present, parse them.
1186+ if (parser.parseOptionalAttrDictWithKeyword (state.attributes ))
1187+ return failure ();
1188+
1189+ // Add the attributes to the function arguments.
1190+ assert (argAttrs.size () == argTypes.size ());
1191+ assert (resultAttrs.size () == resultTypes.size ());
1192+ function_interface_impl::addArgAndResultAttrs (builder, state, argAttrs,
1193+ resultAttrs);
1194+
1195+ // Parse the optional function body.
1196+ auto *body = state.addRegion ();
1197+ OptionalParseResult result = parser.parseOptionalRegion (
1198+ *body, entryArgs, entryArgs.empty () ? ArrayRef<Type>() : argTypes);
1199+ return failure (result.hasValue () && failed (*result));
1200+ }
1201+
1202+ static void print (cir::FuncOp op, OpAsmPrinter &p) {
1203+ p << ' ' ;
1204+ if (op.linkage () != GlobalLinkageKind::ExternalLinkage)
1205+ p << stringifyGlobalLinkageKind (op.linkage ()) << ' ' ;
1206+
1207+ // Print function name, signature, and control.
1208+ p.printSymbolName (op.sym_name ());
1209+ auto fnType = op.getType ();
1210+ function_interface_impl::printFunctionSignature (p, op, fnType.getInputs (),
1211+ /* isVariadic=*/ false ,
1212+ fnType.getResults ());
1213+ function_interface_impl::printFunctionAttributes (p, op, fnType.getNumInputs (),
1214+ fnType.getNumResults (), {});
1215+
1216+ // Print the body if this is not an external function.
1217+ Region &body = op.body ();
1218+ if (!body.empty ())
1219+ p.printRegion (body, /* printEntryBlockArgs=*/ false ,
1220+ /* printBlockTerminators=*/ true );
1221+ }
1222+
1223+ // Hook for OpTrait::FunctionLike, called after verifying that the 'type'
1224+ // attribute is present. This can check for preconditions of the
1225+ // getNumArguments hook not failing.
1226+ LogicalResult cir::FuncOp::verifyType () {
1227+ auto type = getTypeAttr ().getValue ();
1228+ if (!type.isa <FunctionType>())
1229+ return emitOpError (" requires '" + getTypeAttrName () +
1230+ " ' attribute of function type" );
1231+ if (getType ().getNumResults () > 1 )
1232+ return emitOpError (" cannot have more than one result" );
1233+ return success ();
1234+ }
1235+
1236+ // Verifies linkage types, similar to LLVM:
1237+ // - functions don't have 'common' linkage
1238+ // - external functions have 'external' or 'extern_weak' linkage
1239+ static LogicalResult verify (cir::FuncOp op) {
1240+ if (op.linkage () == cir::GlobalLinkageKind::CommonLinkage)
1241+ return op.emitOpError ()
1242+ << " functions cannot have '"
1243+ << stringifyGlobalLinkageKind (cir::GlobalLinkageKind::CommonLinkage)
1244+ << " ' linkage" ;
1245+
1246+ if (op.isExternal ()) {
1247+ if (op.linkage () != cir::GlobalLinkageKind::ExternalLinkage &&
1248+ op.linkage () != cir::GlobalLinkageKind::ExternalWeakLinkage)
1249+ return op.emitOpError ()
1250+ << " external functions must have '"
1251+ << stringifyGlobalLinkageKind (
1252+ cir::GlobalLinkageKind::ExternalLinkage)
1253+ << " ' or '"
1254+ << stringifyGlobalLinkageKind (
1255+ cir::GlobalLinkageKind::ExternalWeakLinkage)
1256+ << " ' linkage" ;
1257+ return success ();
1258+ }
1259+ return success ();
1260+ }
1261+
10751262// ===----------------------------------------------------------------------===//
10761263// CIR defined traits
10771264// ===----------------------------------------------------------------------===//
0 commit comments