-
Notifications
You must be signed in to change notification settings - Fork 5.3k
Optimize FMA codegen based on the overwritten #58196
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 34 commits
ee2c0b6
46d0011
cce4bda
b825291
f615e39
b698036
7d9c0d6
1344d92
029a9b5
f2a371f
9955389
7c56653
1d51caa
091133e
9a6ae44
ffcff76
5641f8f
b7312ac
a325fe3
0f950dd
33a596d
5da9368
c3a9f07
9e356aa
f8159bc
18bbe4d
2ca2524
17bd967
eed5912
43c5034
5ef70a5
bfa6924
12f260b
5ca658e
ec4ef66
aa93a85
c66a018
ff5a433
a4657c7
75d7a37
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -21818,6 +21818,52 @@ uint16_t GenTreeLclVarCommon::GetLclOffs() const | |||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| #if defined(TARGET_XARCH) && defined(FEATURE_HW_INTRINSICS) | ||||||||||||||||||
| //------------------------------------------------------------------------ | ||||||||||||||||||
| // GetResultOpNumForFMA: check if the result is written into one of the operands. | ||||||||||||||||||
| // In the case that none of the operands is overwritten, check if any of them is lastUse. | ||||||||||||||||||
| // | ||||||||||||||||||
| // Return Value: | ||||||||||||||||||
| // The operand number overwritten or lastUse. 0 is the default value, where the result is written into | ||||||||||||||||||
| // a destination that is not one of the source operands and there is no last use op. | ||||||||||||||||||
| // | ||||||||||||||||||
| unsigned GenTreeHWIntrinsic::GetResultOpNumForFMA(GenTree* use, GenTree* op1, GenTree* op2, GenTree* op3) | ||||||||||||||||||
| { | ||||||||||||||||||
| // only FMA intrinsic node should call into this function | ||||||||||||||||||
| assert(HWIntrinsicInfo::lookupIsa(gtHWIntrinsicId) == InstructionSet_FMA); | ||||||||||||||||||
| if (use != nullptr && use->OperIs(GT_STORE_LCL_VAR)) | ||||||||||||||||||
| { | ||||||||||||||||||
| // For store_lcl_var, check if any op is overwritten | ||||||||||||||||||
|
|
||||||||||||||||||
| GenTreeLclVarCommon* overwritten = use->AsLclVarCommon(); | ||||||||||||||||||
| unsigned overwrittenLclNum = overwritten->GetLclNum(); | ||||||||||||||||||
| if (op1->IsLocal() && op1->AsLclVarCommon()->GetLclNum() == overwrittenLclNum) | ||||||||||||||||||
| { | ||||||||||||||||||
| return 1; | ||||||||||||||||||
| } | ||||||||||||||||||
| else if (op2->IsLocal() && op2->AsLclVarCommon()->GetLclNum() == overwrittenLclNum) | ||||||||||||||||||
| { | ||||||||||||||||||
| return 2; | ||||||||||||||||||
| } | ||||||||||||||||||
| else if (op3->IsLocal() && op3->AsLclVarCommon()->GetLclNum() == overwrittenLclNum) | ||||||||||||||||||
| { | ||||||||||||||||||
| return 3; | ||||||||||||||||||
| } | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| // If no overwritten op, check if there is any last use op | ||||||||||||||||||
|
|
||||||||||||||||||
| if (op1->OperIs(GT_LCL_VAR) && op1->IsLastUse(0)) | ||||||||||||||||||
| return 1; | ||||||||||||||||||
| else if (op3->OperIs(GT_LCL_VAR) && op3->IsLastUse(0)) | ||||||||||||||||||
| return 3; | ||||||||||||||||||
| else if (op2->OperIs(GT_LCL_VAR) && op2->IsLastUse(0)) | ||||||||||||||||||
| return 2; | ||||||||||||||||||
|
||||||||||||||||||
| else if (op3->OperIs(GT_LCL_VAR) && op3->IsLastUse(0)) | |
| return 3; | |
| else if (op2->OperIs(GT_LCL_VAR) && op2->IsLastUse(0)) | |
| return 2; | |
| else if (op2->OperIs(GT_LCL_VAR) && op2->IsLastUse(0)) | |
| return 2; | |
| else if (op3->OperIs(GT_LCL_VAR) && op3->IsLastUse(0)) | |
| return 3; |
The reasoning is that this method is picking a preference for "overwritten op".
Preferencing op1 as the first check here makes sense because scalar ops "copy upper bits" and therefore if we're in that scenario, op1 is the only operand that can "be the target" as it were, the others will have to be contained or delay free.
Preferencing op2 after that (as the secondary preference) simply keeps it consistent with the op1->IsLocal checks above and results in the "least" amount of operand swapping if we order the containment checks accordingly.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2106,7 +2106,9 @@ void CodeGen::genFMAIntrinsic(GenTreeHWIntrinsic* node) | |
| NamedIntrinsic intrinsicId = node->gtHWIntrinsicId; | ||
| var_types baseType = node->GetSimdBaseType(); | ||
| emitAttr attr = emitActualTypeSize(Compiler::getSIMDTypeForSize(node->GetSimdSize())); | ||
| instruction ins = HWIntrinsicInfo::lookupIns(intrinsicId, baseType); | ||
| instruction ins = HWIntrinsicInfo::lookupIns(intrinsicId, baseType); // 213 form | ||
| instruction _132form = (instruction)(ins - 1); | ||
| instruction _231form = (instruction)(ins + 1); | ||
| GenTree* op1 = node->gtGetOp1(); | ||
| regNumber targetReg = node->GetRegNum(); | ||
|
|
||
|
|
@@ -2122,43 +2124,75 @@ void CodeGen::genFMAIntrinsic(GenTreeHWIntrinsic* node) | |
| argList = argList->Rest(); | ||
| GenTree* op3 = argList->Current(); | ||
|
|
||
| regNumber op1Reg; | ||
| regNumber op2Reg; | ||
| regNumber op1NodeReg = op1->GetRegNum(); | ||
| regNumber op2NodeReg = op2->GetRegNum(); | ||
| regNumber op3NodeReg = op3->GetRegNum(); | ||
|
|
||
| GenTree* emitOp1 = op1; | ||
| GenTree* emitOp2 = op2; | ||
| GenTree* emitOp3 = op3; | ||
|
|
||
| bool isCommutative = false; | ||
| const bool copiesUpperBits = HWIntrinsicInfo::CopiesUpperBits(intrinsicId); | ||
|
|
||
| // Intrinsics with CopyUpperBits semantics cannot have op1 be contained | ||
| assert(!copiesUpperBits || !op1->isContained()); | ||
|
|
||
| if (op2->isContained() || op2->isUsedFromSpillTemp()) | ||
| if (op1->isContained() || op1->isUsedFromSpillTemp()) | ||
| { | ||
| // 132 form: op1 = (op1 * op3) + [op2] | ||
|
|
||
| ins = (instruction)(ins - 1); | ||
| op1Reg = op1->GetRegNum(); | ||
| op2Reg = op3->GetRegNum(); | ||
| op3 = op2; | ||
| if (targetReg == op2NodeReg) | ||
| { | ||
| std::swap(emitOp1, emitOp2); | ||
| // op2 = ([op1] * op2) + op3 | ||
| // 132 form: XMM1 = (XMM1 * [XMM3]) + XMM2 | ||
| ins = _132form; | ||
| std::swap(emitOp2, emitOp3); | ||
| } | ||
| else | ||
| { | ||
| // targetReg == op3NodeReg or targetReg == ? | ||
| // op3 = ([op1] * op2) + op3 | ||
| // 231 form: XMM1 = (XMM2 * [XMM3]) + XMM1 | ||
| ins = _231form; | ||
| std::swap(emitOp1, emitOp3); | ||
| } | ||
| } | ||
| else if (op1->isContained() || op1->isUsedFromSpillTemp()) | ||
| else if (op2->isContained() || op2->isUsedFromSpillTemp()) | ||
| { | ||
| // 231 form: op3 = (op2 * op3) + [op1] | ||
|
|
||
| ins = (instruction)(ins + 1); | ||
| op1Reg = op3->GetRegNum(); | ||
| op2Reg = op2->GetRegNum(); | ||
| op3 = op1; | ||
| if (targetReg == op3NodeReg) | ||
|
||
| { | ||
| // op3 = (op1 * [op2]) + op3 | ||
| // 231 form: XMM1 = (XMM2 * [XMM3]) + XMM1 | ||
| ins = _231form; | ||
| std::swap(emitOp1, emitOp3); | ||
| } | ||
| else | ||
| { | ||
| // targetReg == op1NodeReg or targetReg == ? | ||
| // op1 = (op1 * [op2]) + op3 | ||
| // 132 form: XMM1 = (XMM1 * [XMM3]) + XMM2 | ||
| ins = _132form; | ||
| } | ||
| std::swap(emitOp2, emitOp3); | ||
| } | ||
| else | ||
| { | ||
| // 213 form: op1 = (op2 * op1) + [op3] | ||
|
|
||
| op1Reg = op1->GetRegNum(); | ||
| op2Reg = op2->GetRegNum(); | ||
|
|
||
| isCommutative = !copiesUpperBits; | ||
| // targetReg could be op1NodeReg, op2NodeReg, or not equal to any op | ||
| // op1 = (op1 * op2) + [op3] or op2 = (op1 * op2) + [op3] | ||
| // ? = (op1 * op2) + [op3] or ? = (op1 * op2) + op3 | ||
| // 213 form: XMM1 = (XMM2 * XMM1) + [XMM3] | ||
| isCommutative = copiesUpperBits; | ||
|
||
| if (targetReg == op2NodeReg) | ||
|
||
| { | ||
| // op2 = (op1 * op2) + [op3] | ||
| // 213 form: XMM1 = (XMM2 * XMM1) + [XMM3] | ||
| std::swap(emitOp1, emitOp2); | ||
| } | ||
| } | ||
|
|
||
| regNumber op1Reg = emitOp1->GetRegNum(); | ||
| regNumber op2Reg = emitOp2->GetRegNum(); | ||
|
|
||
| if (isCommutative && (op1Reg != targetReg) && (op2Reg == targetReg)) | ||
|
||
| { | ||
| assert(node->isRMWHWIntrinsic(compiler)); | ||
|
|
@@ -2173,8 +2207,7 @@ void CodeGen::genFMAIntrinsic(GenTreeHWIntrinsic* node) | |
| op2Reg = op1Reg; | ||
| op1Reg = targetReg; | ||
| } | ||
|
|
||
| genHWIntrinsic_R_R_R_RM(ins, attr, targetReg, op1Reg, op2Reg, op3); | ||
| genHWIntrinsic_R_R_R_RM(ins, attr, targetReg, op1Reg, op2Reg, emitOp3); | ||
| genProduceReg(node); | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6335,38 +6335,51 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node) | |
| { | ||
| if ((intrinsicId >= NI_FMA_MultiplyAdd) && (intrinsicId <= NI_FMA_MultiplySubtractNegatedScalar)) | ||
| { | ||
| bool supportsRegOptional = false; | ||
| bool supportsOp1RegOptional = false; | ||
| bool supportsOp2RegOptional = false; | ||
| bool supportsOp3RegOptional = false; | ||
| unsigned resultOpNum = 0; | ||
| LIR::Use use; | ||
| GenTree* user = nullptr; | ||
|
|
||
| if (BlockRange().TryGetUse(node, &use)) | ||
| { | ||
| user = use.User(); | ||
| } | ||
| resultOpNum = node->GetResultOpNumForFMA(user, op1, op2, op3); | ||
|
|
||
| // Prioritize Containable op. Check if any one of the op is containable first. | ||
| // Set op regOptional only if none of them is containable. | ||
|
|
||
| if (IsContainableHWIntrinsicOp(node, op3, &supportsRegOptional)) | ||
| if (resultOpNum != 1 && IsContainableHWIntrinsicOp(node, op1, &supportsOp1RegOptional) && | ||
|
||
| !HWIntrinsicInfo::CopiesUpperBits(intrinsicId)) | ||
|
||
| { | ||
| // 213 form: op1 = (op2 * op1) + [op3] | ||
| MakeSrcContained(node, op3); | ||
| // result = ([op1] * op2) + op3 | ||
| MakeSrcContained(node, op1); | ||
| } | ||
| else if (IsContainableHWIntrinsicOp(node, op2, &supportsRegOptional)) | ||
| else if (resultOpNum != 2 && IsContainableHWIntrinsicOp(node, op2, &supportsOp2RegOptional)) | ||
| { | ||
| // 132 form: op1 = (op1 * op3) + [op2] | ||
| // result = (op1 * [op2]) + op3 | ||
| MakeSrcContained(node, op2); | ||
| } | ||
| else if (IsContainableHWIntrinsicOp(node, op1, &supportsRegOptional)) | ||
| else if (resultOpNum != 3 && IsContainableHWIntrinsicOp(node, op3, &supportsOp3RegOptional)) | ||
| { | ||
| // Intrinsics with CopyUpperBits semantics cannot have op1 be contained | ||
|
|
||
| if (!HWIntrinsicInfo::CopiesUpperBits(intrinsicId)) | ||
| { | ||
| // 231 form: op3 = (op2 * op3) + [op1] | ||
| MakeSrcContained(node, op1); | ||
| } | ||
| // result = (op1 * op2) + [op3] | ||
| MakeSrcContained(node, op3); | ||
| } | ||
| else if (resultOpNum != 1 && !HWIntrinsicInfo::CopiesUpperBits(intrinsicId)) | ||
| { | ||
| assert(supportsOp1RegOptional); | ||
| op1->SetRegOptional(); | ||
| } | ||
tannergooding marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| else if (resultOpNum != 2) | ||
| { | ||
| assert(supportsOp2RegOptional); | ||
| op2->SetRegOptional(); | ||
| } | ||
| else | ||
| { | ||
| assert(supportsRegOptional); | ||
|
|
||
| // TODO-XArch-CQ: Technically any one of the three operands can | ||
| // be reg-optional. With a limitation on op1 where | ||
| // it can only be so if CopyUpperBits is off. | ||
| // https://github.com/dotnet/runtime/issues/6358 | ||
|
|
||
| // 213 form: op1 = (op2 * op1) + op3 | ||
| assert(supportsOp3RegOptional); | ||
| op3->SetRegOptional(); | ||
tannergooding marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2328,48 +2328,77 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree) | |
|
|
||
| const bool copiesUpperBits = HWIntrinsicInfo::CopiesUpperBits(intrinsicId); | ||
|
|
||
| // Intrinsics with CopyUpperBits semantics cannot have op1 be contained | ||
| assert(!copiesUpperBits || !op1->isContained()); | ||
| unsigned resultOpNum = 0; | ||
| LIR::Use use; | ||
| GenTree* user = nullptr; | ||
|
|
||
| if (op2->isContained()) | ||
| if (LIR::AsRange(blockSequence[curBBSeqNum]).TryGetUse(intrinsicTree, &use)) | ||
| { | ||
| // 132 form: op1 = (op1 * op3) + [op2] | ||
| user = use.User(); | ||
| } | ||
| resultOpNum = intrinsicTree->GetResultOpNumForFMA(user, op1, op2, op3); | ||
|
|
||
| tgtPrefUse = BuildUse(op1); | ||
| // Intrinsics with CopyUpperBits semantics cannot have op1 be contained | ||
| assert(!copiesUpperBits || !op1->isContained()); | ||
|
|
||
| srcCount += 1; | ||
| srcCount += BuildOperandUses(op2); | ||
| srcCount += BuildDelayFreeUses(op3, op1); | ||
| unsigned containedOpNum = 0; | ||
|
|
||
| if (op1->isContained() || op1->IsRegOptional()) | ||
| { | ||
| containedOpNum = 1; | ||
| } | ||
| else if (op1->isContained()) | ||
| else if (op2->isContained() || op2->IsRegOptional()) | ||
| { | ||
| // 231 form: op3 = (op2 * op3) + [op1] | ||
|
|
||
| tgtPrefUse = BuildUse(op3); | ||
|
|
||
| srcCount += BuildOperandUses(op1); | ||
| srcCount += BuildDelayFreeUses(op2, op1); | ||
| srcCount += 1; | ||
| containedOpNum = 2; | ||
| } | ||
| else | ||
| { | ||
| // 213 form: op1 = (op2 * op1) + [op3] | ||
| assert(op3->isContained() || op3->IsRegOptional()); | ||
|
||
| containedOpNum = 3; | ||
| } | ||
|
|
||
| tgtPrefUse = BuildUse(op1); | ||
| srcCount += 1; | ||
| GenTree* emitOp1 = op1; | ||
| GenTree* emitOp2 = op2; | ||
| GenTree* emitOp3 = op3; | ||
|
|
||
| if (copiesUpperBits) | ||
| // Intrinsics with CopyUpperBits semantics must have op1 as target | ||
| if (containedOpNum == 1 && !copiesUpperBits) | ||
tannergooding marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| { | ||
|
||
| if (resultOpNum != 3) | ||
|
||
| { | ||
| srcCount += BuildDelayFreeUses(op2, op1); | ||
| // op2 = ([op1] * op2) + op3 | ||
| std::swap(emitOp2, emitOp3); | ||
| } | ||
| else | ||
|
|
||
| // else: op3 = ([op1] * op2) + op3 | ||
| std::swap(emitOp1, emitOp3); | ||
| } | ||
| else if (containedOpNum == 3) | ||
| { | ||
| if (resultOpNum == 2 && !copiesUpperBits) | ||
| { | ||
| tgtPrefUse2 = BuildUse(op2); | ||
| srcCount += 1; | ||
| // op2 = (op1 * op2) + [op3] | ||
| std::swap(emitOp1, emitOp2); | ||
| } | ||
| // else: op1 = (op1 * op2) + [op3] | ||
| } | ||
| else | ||
| { | ||
| assert(containedOpNum == 2); | ||
|
||
| // op1 = (op1 * [op2]) + op3 | ||
| std::swap(emitOp2, emitOp3); | ||
|
|
||
| srcCount += op3->isContained() ? BuildOperandUses(op3) : BuildDelayFreeUses(op3, op1); | ||
| if (resultOpNum == 3 && !copiesUpperBits) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Just capturing a comment, I don't think we need to do anything in this PR. I think the logic around |
||
| { | ||
| // op3 = (op1 * [op2]) + op3 | ||
| std::swap(emitOp1, emitOp2); | ||
| } | ||
| } | ||
| tgtPrefUse = BuildUse(emitOp1); | ||
|
|
||
| srcCount += 1; | ||
| srcCount += BuildDelayFreeUses(emitOp2, emitOp1); | ||
| srcCount += emitOp3->isContained() ? BuildOperandUses(emitOp3) : BuildDelayFreeUses(emitOp3, emitOp1); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a lot smaller and easier to follow now 🎉 |
||
|
|
||
| buildUses = false; | ||
| break; | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.