Skip to content

Commit 3af9cb5

Browse files
bcardosolopeslanza
authored andcommitted
[CIR] Add cir::FuncOp operation
This is necessary to enforce the presence of linkage kind, which is key to continue implementing methods. While here add some helpers to apply more reliable parsing for linkage types. CodeGen does not use cir::FuncOp just yet, patch a few places that to keep using mlir::FuncOp for now.
1 parent 9466122 commit 3af9cb5

File tree

3 files changed

+295
-13
lines changed

3 files changed

+295
-13
lines changed

mlir/include/mlir/Dialect/CIR/IR/CIROps.td

Lines changed: 98 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717
include "mlir/Dialect/CIR/IR/CIRDialect.td"
1818
include "mlir/Dialect/CIR/IR/CIRTypes.td"
1919
include "mlir/Dialect/CIR/IR/CIRAttrs.td"
20-
include "mlir/IR/SymbolInterfaces.td"
20+
include "mlir/Interfaces/CallInterfaces.td"
2121
include "mlir/Interfaces/ControlFlowInterfaces.td"
2222
include "mlir/Interfaces/LoopLikeInterface.td"
2323
include "mlir/Interfaces/InferTypeOpInterface.td"
2424
include "mlir/Interfaces/SideEffectInterfaces.td"
25+
include "mlir/IR/FunctionInterfaces.td"
2526
include "mlir/IR/SymbolInterfaces.td"
2627

2728
//===----------------------------------------------------------------------===//
@@ -327,7 +328,7 @@ def StoreOp : CIR_Op<"store", [
327328
// ReturnOp
328329
//===----------------------------------------------------------------------===//
329330

330-
def ReturnOp : CIR_Op<"return", [HasParent<"FuncOp, ScopeOp, IfOp, SwitchOp, LoopOp">,
331+
def ReturnOp : CIR_Op<"return", [HasParent<"mlir::FuncOp, ScopeOp, IfOp, SwitchOp, LoopOp">,
331332
Terminator]> {
332333
let summary = "Return from function";
333334
let description = [{
@@ -1106,4 +1107,98 @@ def StructElementAddr : CIR_Op<"struct_element_addr"> {
11061107
// FIXME: add verifier.
11071108
}
11081109

1109-
#endif // MLIR_CIR_DIALECT_CIR_OPS
1110+
//===----------------------------------------------------------------------===//
1111+
// FuncOp
1112+
//===----------------------------------------------------------------------===//
1113+
1114+
def FuncOp : CIR_Op<"func", [
1115+
AutomaticAllocationScope, CallableOpInterface, FunctionOpInterface,
1116+
IsolatedFromAbove, Symbol
1117+
]> {
1118+
let summary = "Declare or define a function";
1119+
let description = [{
1120+
1121+
Similar to `mlir::FuncOp` built-in:
1122+
> Operations within the function cannot implicitly capture values defined
1123+
> outside of the function, i.e. Functions are `IsolatedFromAbove`. All
1124+
> external references must use function arguments or attributes that establish
1125+
> a symbolic connection (e.g. symbols referenced by name via a string
1126+
> attribute like SymbolRefAttr). An external function declaration (used when
1127+
> referring to a function declared in some other module) has no body. While
1128+
> the MLIR textual form provides a nice inline syntax for function arguments,
1129+
> they are internally represented as “block arguments” to the first block in
1130+
> the region.
1131+
>
1132+
> Only dialect attribute names may be specified in the attribute dictionaries
1133+
> for function arguments, results, or the function itself.
1134+
1135+
The function linkage information is specified by `linkage`, as defined by
1136+
`GlobalLinkageKind` attribute.
1137+
1138+
Example:
1139+
1140+
```mlir
1141+
// External function definitions.
1142+
func @abort()
1143+
1144+
// A function with internal linkage.
1145+
func internal @count(%x: i64) -> (i64)
1146+
return %x : i64
1147+
}
1148+
```
1149+
}];
1150+
1151+
let arguments = (ins SymbolNameAttr:$sym_name,
1152+
TypeAttr:$type,
1153+
DefaultValuedAttr<GlobalLinkageKind,
1154+
"GlobalLinkageKind::ExternalLinkage">:$linkage,
1155+
OptionalAttr<StrAttr>:$sym_visibility);
1156+
let regions = (region AnyRegion:$body);
1157+
let skipDefaultBuilders = 1;
1158+
1159+
let builders = [OpBuilder<(ins
1160+
"StringRef":$name, "FunctionType":$type,
1161+
CArg<"GlobalLinkageKind", "GlobalLinkageKind::ExternalLinkage">:$linkage,
1162+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs,
1163+
CArg<"ArrayRef<DictionaryAttr>", "{}">:$argAttrs)
1164+
>];
1165+
1166+
let extraClassDeclaration = [{
1167+
// Returns the type of this function.
1168+
// FIXME: We should derive this via the ODS `type` param.
1169+
FunctionType getType() {
1170+
return getTypeAttr().getValue().cast<FunctionType>();
1171+
}
1172+
1173+
/// Returns the region on the current operation that is callable. This may
1174+
/// return null in the case of an external callable object, e.g. an external
1175+
/// function.
1176+
::mlir::Region *getCallableRegion() {
1177+
return isExternal() ? nullptr : &getBody();
1178+
}
1179+
1180+
/// Returns the results types that the callable region produces when
1181+
/// executed.
1182+
ArrayRef<Type> getCallableResults() { return getType().getResults(); }
1183+
1184+
/// This trait needs access to the hooks defined below.
1185+
friend struct FunctionOpInterfaceTrait<mlir::FuncOp>;
1186+
1187+
/// Returns the argument types of this function.
1188+
ArrayRef<Type> getArgumentTypes() { return getType().getInputs(); }
1189+
1190+
/// Returns the result types of this function.
1191+
ArrayRef<Type> getResultTypes() { return getType().getResults(); }
1192+
1193+
/// Hook for OpTrait::FunctionOpInterfaceTrait, called after verifying that
1194+
/// the 'type' attribute is present and checks if it holds a function type.
1195+
/// Ensures getType, getNumFuncArguments, and getNumFuncResults can be
1196+
/// called safely.
1197+
LogicalResult verifyType();
1198+
}];
1199+
let parser = [{ return ::parseCIRFuncOp(parser, result); }];
1200+
let printer = [{ return ::print(*this, p); }];
1201+
let verifier = [{ return ::verify(*this); }];
1202+
}
1203+
1204+
#endif // MLIR_CIR_DIALECT_CIR_OPS

mlir/lib/Dialect/CIR/IR/CIRDialect.cpp

Lines changed: 196 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
1717
#include "mlir/IR/Builders.h"
1818
#include "mlir/IR/DialectImplementation.h"
19+
#include "mlir/IR/FunctionImplementation.h"
1920
#include "mlir/IR/OpDefinition.h"
2021
#include "mlir/IR/OpImplementation.h"
2122
#include "mlir/IR/TypeUtilities.h"
@@ -58,6 +59,54 @@ void cir::CIRDialect::initialize() {
5859
addInterfaces<CIROpAsmDialectInterface>();
5960
}
6061

62+
//===----------------------------------------------------------------------===//
63+
// Helpers
64+
//===----------------------------------------------------------------------===//
65+
66+
// Parses one of the keywords provided in the list `keywords` and returns the
67+
// position of the parsed keyword in the list. If none of the keywords from the
68+
// list is parsed, returns -1.
69+
static int parseOptionalKeywordAlternative(OpAsmParser &parser,
70+
ArrayRef<StringRef> keywords) {
71+
for (auto en : llvm::enumerate(keywords)) {
72+
if (succeeded(parser.parseOptionalKeyword(en.value())))
73+
return en.index();
74+
}
75+
return -1;
76+
}
77+
78+
namespace {
79+
template <typename Ty>
80+
struct EnumTraits {};
81+
82+
#define REGISTER_ENUM_TYPE(Ty) \
83+
template <> \
84+
struct EnumTraits<Ty> { \
85+
static StringRef stringify(Ty value) { return stringify##Ty(value); } \
86+
static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \
87+
}
88+
89+
REGISTER_ENUM_TYPE(GlobalLinkageKind);
90+
} // namespace
91+
92+
/// Parse an enum from the keyword, or default to the provided default value.
93+
/// The return type is the enum type by default, unless overriden with the
94+
/// second template argument.
95+
/// TODO: teach other places in this file to use this function.
96+
template <typename EnumTy, typename RetTy = EnumTy>
97+
static RetTy parseOptionalCIRKeyword(OpAsmParser &parser,
98+
OperationState &result,
99+
EnumTy defaultValue) {
100+
SmallVector<StringRef, 10> names;
101+
for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
102+
names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
103+
104+
int index = parseOptionalKeywordAlternative(parser, names);
105+
if (index == -1)
106+
return static_cast<RetTy>(defaultValue);
107+
return static_cast<RetTy>(index);
108+
}
109+
61110
//===----------------------------------------------------------------------===//
62111
// ConstantOp
63112
//===----------------------------------------------------------------------===//
@@ -190,7 +239,7 @@ static LogicalResult verify(cir::CastOp castOp) {
190239
//===----------------------------------------------------------------------===//
191240

192241
static mlir::LogicalResult checkReturnAndFunction(ReturnOp op,
193-
FuncOp function) {
242+
mlir::FuncOp function) {
194243
// ReturnOps currently only have a single optional operand.
195244
if (op.getNumOperands() > 1)
196245
return op.emitOpError() << "expects at most 1 return operand";
@@ -223,11 +272,11 @@ static mlir::LogicalResult verify(ReturnOp op) {
223272
// Returns can be present in multiple different scopes, get the
224273
// wrapping function and start from there.
225274
auto *fnOp = op->getParentOp();
226-
while (!isa<FuncOp>(fnOp))
275+
while (!isa<mlir::FuncOp>(fnOp))
227276
fnOp = fnOp->getParentOp();
228277

229278
// Make sure return types match function return type.
230-
if (checkReturnAndFunction(op, cast<FuncOp>(fnOp)).failed())
279+
if (checkReturnAndFunction(op, cast<mlir::FuncOp>(fnOp)).failed())
231280
return failure();
232281

233282
return success();
@@ -495,7 +544,7 @@ static LogicalResult verify(ScopeOp op) {
495544

496545
static mlir::LogicalResult verify(YieldOp op) {
497546
auto isDominatedByLoopOrSwitch = [](Operation *parentOp) {
498-
while (!llvm::isa<FuncOp>(parentOp)) {
547+
while (!llvm::isa<mlir::FuncOp>(parentOp)) {
499548
if (llvm::isa<cir::SwitchOp, cir::LoopOp>(parentOp))
500549
return true;
501550
parentOp = parentOp->getParentOp();
@@ -504,7 +553,7 @@ static mlir::LogicalResult verify(YieldOp op) {
504553
};
505554

506555
auto isDominatedByLoop = [](Operation *parentOp) {
507-
while (!llvm::isa<FuncOp>(parentOp)) {
556+
while (!llvm::isa<mlir::FuncOp>(parentOp)) {
508557
if (llvm::isa<cir::LoopOp>(parentOp))
509558
return true;
510559
parentOp = parentOp->getParentOp();
@@ -1013,15 +1062,15 @@ static LogicalResult verify(GlobalOp op) {
10131062
}
10141063

10151064
switch (op.linkage()) {
1016-
case mlir::cir::GlobalLinkageKind::InternalLinkage:
1017-
case mlir::cir::GlobalLinkageKind::PrivateLinkage:
1065+
case GlobalLinkageKind::InternalLinkage:
1066+
case GlobalLinkageKind::PrivateLinkage:
10181067
if (op.isPublic())
10191068
return op->emitError()
10201069
<< "public visibility not allowed with '"
10211070
<< stringifyGlobalLinkageKind(op.linkage()) << "' linkage";
10221071
break;
1023-
case mlir::cir::GlobalLinkageKind::ExternalLinkage:
1024-
case mlir::cir::GlobalLinkageKind::ExternalWeakLinkage:
1072+
case GlobalLinkageKind::ExternalLinkage:
1073+
case GlobalLinkageKind::ExternalWeakLinkage:
10251074
if (op.isPrivate())
10261075
return op->emitError()
10271076
<< "private visibility not allowed with '"
@@ -1072,6 +1121,144 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
10721121
return success();
10731122
}
10741123

1124+
//===----------------------------------------------------------------------===//
1125+
// FuncOp
1126+
//===----------------------------------------------------------------------===//
1127+
1128+
/// Returns the name used for the linkage attribute. This *must* correspond to
1129+
/// the name of the attribute in ODS.
1130+
static StringRef getLinkageAttrName() { return "linkage"; }
1131+
1132+
void cir::FuncOp::build(OpBuilder &builder, OperationState &result,
1133+
StringRef name, FunctionType type,
1134+
GlobalLinkageKind linkage,
1135+
ArrayRef<NamedAttribute> attrs,
1136+
ArrayRef<DictionaryAttr> argAttrs) {
1137+
result.addRegion();
1138+
result.addAttribute(SymbolTable::getSymbolAttrName(),
1139+
builder.getStringAttr(name));
1140+
result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
1141+
result.addAttribute(getLinkageAttrName(), GlobalLinkageKindAttr::get(
1142+
builder.getContext(), linkage));
1143+
result.attributes.append(attrs.begin(), attrs.end());
1144+
if (argAttrs.empty())
1145+
return;
1146+
1147+
function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs,
1148+
/*resultAttrs=*/llvm::None);
1149+
}
1150+
1151+
static ParseResult parseCIRFuncOp(OpAsmParser &parser, OperationState &state) {
1152+
// Default to external linkage if no keyword is provided.
1153+
state.addAttribute(
1154+
getLinkageAttrName(),
1155+
GlobalLinkageKindAttr::get(
1156+
parser.getContext(),
1157+
parseOptionalCIRKeyword<GlobalLinkageKind>(
1158+
parser, state, GlobalLinkageKind::ExternalLinkage)));
1159+
1160+
StringAttr nameAttr;
1161+
SmallVector<OpAsmParser::OperandType, 8> entryArgs;
1162+
SmallVector<NamedAttrList, 1> argAttrs;
1163+
SmallVector<Optional<Location>, 1> argLocations;
1164+
SmallVector<NamedAttrList, 1> resultAttrs;
1165+
SmallVector<Type, 8> argTypes;
1166+
SmallVector<Type, 4> resultTypes;
1167+
auto &builder = parser.getBuilder();
1168+
1169+
// Parse the name as a symbol.
1170+
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1171+
state.attributes))
1172+
return failure();
1173+
1174+
// Parse the function signature.
1175+
bool isVariadic = false;
1176+
if (function_interface_impl::parseFunctionSignature(
1177+
parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs,
1178+
argLocations, isVariadic, resultTypes, resultAttrs))
1179+
return failure();
1180+
1181+
auto fnType = builder.getFunctionType(argTypes, resultTypes);
1182+
state.addAttribute(function_interface_impl::getTypeAttrName(),
1183+
TypeAttr::get(fnType));
1184+
1185+
// If additional attributes are present, parse them.
1186+
if (parser.parseOptionalAttrDictWithKeyword(state.attributes))
1187+
return failure();
1188+
1189+
// Add the attributes to the function arguments.
1190+
assert(argAttrs.size() == argTypes.size());
1191+
assert(resultAttrs.size() == resultTypes.size());
1192+
function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
1193+
resultAttrs);
1194+
1195+
// Parse the optional function body.
1196+
auto *body = state.addRegion();
1197+
OptionalParseResult result = parser.parseOptionalRegion(
1198+
*body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes);
1199+
return failure(result.hasValue() && failed(*result));
1200+
}
1201+
1202+
static void print(cir::FuncOp op, OpAsmPrinter &p) {
1203+
p << ' ';
1204+
if (op.linkage() != GlobalLinkageKind::ExternalLinkage)
1205+
p << stringifyGlobalLinkageKind(op.linkage()) << ' ';
1206+
1207+
// Print function name, signature, and control.
1208+
p.printSymbolName(op.sym_name());
1209+
auto fnType = op.getType();
1210+
function_interface_impl::printFunctionSignature(p, op, fnType.getInputs(),
1211+
/*isVariadic=*/false,
1212+
fnType.getResults());
1213+
function_interface_impl::printFunctionAttributes(p, op, fnType.getNumInputs(),
1214+
fnType.getNumResults(), {});
1215+
1216+
// Print the body if this is not an external function.
1217+
Region &body = op.body();
1218+
if (!body.empty())
1219+
p.printRegion(body, /*printEntryBlockArgs=*/false,
1220+
/*printBlockTerminators=*/true);
1221+
}
1222+
1223+
// Hook for OpTrait::FunctionLike, called after verifying that the 'type'
1224+
// attribute is present. This can check for preconditions of the
1225+
// getNumArguments hook not failing.
1226+
LogicalResult cir::FuncOp::verifyType() {
1227+
auto type = getTypeAttr().getValue();
1228+
if (!type.isa<FunctionType>())
1229+
return emitOpError("requires '" + getTypeAttrName() +
1230+
"' attribute of function type");
1231+
if (getType().getNumResults() > 1)
1232+
return emitOpError("cannot have more than one result");
1233+
return success();
1234+
}
1235+
1236+
// Verifies linkage types, similar to LLVM:
1237+
// - functions don't have 'common' linkage
1238+
// - external functions have 'external' or 'extern_weak' linkage
1239+
static LogicalResult verify(cir::FuncOp op) {
1240+
if (op.linkage() == cir::GlobalLinkageKind::CommonLinkage)
1241+
return op.emitOpError()
1242+
<< "functions cannot have '"
1243+
<< stringifyGlobalLinkageKind(cir::GlobalLinkageKind::CommonLinkage)
1244+
<< "' linkage";
1245+
1246+
if (op.isExternal()) {
1247+
if (op.linkage() != cir::GlobalLinkageKind::ExternalLinkage &&
1248+
op.linkage() != cir::GlobalLinkageKind::ExternalWeakLinkage)
1249+
return op.emitOpError()
1250+
<< "external functions must have '"
1251+
<< stringifyGlobalLinkageKind(
1252+
cir::GlobalLinkageKind::ExternalLinkage)
1253+
<< "' or '"
1254+
<< stringifyGlobalLinkageKind(
1255+
cir::GlobalLinkageKind::ExternalWeakLinkage)
1256+
<< "' linkage";
1257+
return success();
1258+
}
1259+
return success();
1260+
}
1261+
10751262
//===----------------------------------------------------------------------===//
10761263
// CIR defined traits
10771264
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)