diff --git a/src/coreclr/vm/asyncthunks.cpp b/src/coreclr/vm/asyncthunks.cpp index c47ae047873c9f..e48fa1410ecaab 100644 --- a/src/coreclr/vm/asyncthunks.cpp +++ b/src/coreclr/vm/asyncthunks.cpp @@ -65,7 +65,7 @@ void MethodDesc::EmitTaskReturningThunk(MethodDesc* pAsyncCallVariant, MetaSig& _ASSERTE(!pAsyncCallVariant->IsAsyncThunkMethod()); // Emits roughly the following code: - // + // // ExecutionAndSyncBlockStore store = default; // store.Push(); // try @@ -569,7 +569,15 @@ void MethodDesc::EmitAsyncMethodThunk(MethodDesc* pTaskReturningVariant, MetaSig } // other(arg) - pCode->EmitCALL(userFuncToken, localArg, 1); + if (pTaskReturningVariant->IsAbstract()) + { + _ASSERTE(pTaskReturningVariant->IsCLRToCOMCall()); + pCode->EmitCALLVIRT(userFuncToken, localArg, 1); + } + else + { + pCode->EmitCALL(userFuncToken, localArg, 1); + } TypeHandle thLogicalRetType = msig.GetRetTypeHandleThrowing(); if (IsValueTaskAsyncThunk()) diff --git a/src/coreclr/vm/class.cpp b/src/coreclr/vm/class.cpp index cb984506c3f9aa..dbc9b2514eb188 100644 --- a/src/coreclr/vm/class.cpp +++ b/src/coreclr/vm/class.cpp @@ -2828,13 +2828,15 @@ void SparseVTableMap::AllocOrExpand() } //******************************************************************************* -// While building mapping list, record a gap in VTable slot numbers. +// While building mapping list, record a gap in VTable slot numbers or MT slots. +// A positive number indicates a gap in the VTable slot numbers. +// A negative number indicates a gap in the MT slots. void SparseVTableMap::RecordGap(WORD StartMTSlot, WORD NumSkipSlots) { STANDARD_VM_CONTRACT; _ASSERTE((StartMTSlot == 0) || (StartMTSlot > m_MTSlot)); - _ASSERTE(NumSkipSlots > 0); + _ASSERTE(NumSkipSlots != 0); // We use the information about the current gap to complete a map entry for // the last non-gap. There is a special case where the vtable begins with a @@ -2860,6 +2862,14 @@ void SparseVTableMap::RecordGap(WORD StartMTSlot, WORD NumSkipSlots) m_MapEntries++; } +//******************************************************************************* +// While building mapping list, record an excluded MT slot. +void SparseVTableMap::RecordExcludedMethod(WORD MTSlot) +{ + WRAPPER_NO_CONTRACT; + return RecordGap(MTSlot, -1); +} + //******************************************************************************* // Finish creation of mapping list. void SparseVTableMap::FinalizeMapping(WORD TotalMTSlots) diff --git a/src/coreclr/vm/class.h b/src/coreclr/vm/class.h index 6cd8e5a6939ee3..d130506c408898 100644 --- a/src/coreclr/vm/class.h +++ b/src/coreclr/vm/class.h @@ -289,6 +289,9 @@ class SparseVTableMap // occurs. void RecordGap(WORD StartMTSlot, WORD NumSkipSlots); + // Record that the method table slot at MTSlot is excluded from the VT slots. + void RecordExcludedMethod(WORD MTSlot); + // Then call FinalizeMapping to create the actual mapping list. void FinalizeMapping(WORD TotalMTSlots); diff --git a/src/coreclr/vm/clrtocomcall.cpp b/src/coreclr/vm/clrtocomcall.cpp index 0e6add81fc3d87..eff0398678e30c 100644 --- a/src/coreclr/vm/clrtocomcall.cpp +++ b/src/coreclr/vm/clrtocomcall.cpp @@ -212,6 +212,7 @@ I4ARRAYREF SetUpWrapperInfo(MethodDesc *pMD) MODE_COOPERATIVE; INJECT_FAULT(COMPlusThrowOM()); PRECONDITION(CheckPointer(pMD)); + PRECONDITION(!pMD->IsAsyncMethod()); } CONTRACTL_END; @@ -229,13 +230,6 @@ I4ARRAYREF SetUpWrapperInfo(MethodDesc *pMD) WrapperTypeArr = (I4ARRAYREF)AllocatePrimitiveArray(ELEMENT_TYPE_I4, numArgs); GCX_PREEMP(); - - - // TODO: (async) revisit and examine if this needs to be supported somehow - if (pMD->IsAsyncMethod()) - { - ThrowHR(COR_E_NOTSUPPORTED); - } // Collects ParamDef information in an indexed array where element 0 represents // the return type. @@ -511,11 +505,9 @@ UINT32 CLRToCOMLateBoundWorker( LPCUTF8 strMemberName; ULONG uSemantic; - // TODO: (async) revisit and examine if this needs to be supported somehow - if (pItfMD->IsAsyncMethod()) - { - ThrowHR(COR_E_NOTSUPPORTED); - } + // We should never see an async method here, as the async variant should go down + // the async stub path and call the non-async variant (which ends up here). + _ASSERTE(!pItfMD->IsAsyncMethod()); // See if there is property information for this member. hr = pItfMT->GetMDImport()->GetPropertyInfoForMethodDef(pItfMD->GetMemberDef(), &propToken, &strMemberName, &uSemantic); diff --git a/src/coreclr/vm/comcallablewrapper.cpp b/src/coreclr/vm/comcallablewrapper.cpp index 1ca1460454c19d..1b0c6eea5a4e45 100644 --- a/src/coreclr/vm/comcallablewrapper.cpp +++ b/src/coreclr/vm/comcallablewrapper.cpp @@ -3313,9 +3313,10 @@ void ComMethodTable::LayOutClassMethodTable() pCurrParentInteropMD = &pCurrParentInteropMT->pVTable[i]; pParentMD = pCurrParentInteropMD->pMD; - if (pMD && - !(pCurrInteropMD ? IsDuplicateClassItfMD(pCurrInteropMD, i) : IsDuplicateClassItfMD(pMD, i)) && - IsOverloadedComVisibleMember(pMD, pParentMD)) + if (pMD + && !(pCurrInteropMD ? IsDuplicateClassItfMD(pCurrInteropMD, i) : IsDuplicateClassItfMD(pMD, i)) + && IsOverloadedComVisibleMember(pMD, pParentMD) + && !pMD->IsAsyncMethod()) { // some bytes are reserved for CALL xxx before the method desc ComCallMethodDesc* pNewMD = (ComCallMethodDesc *) (pMethodDescMemory + COMMETHOD_PREPAD); @@ -3346,9 +3347,10 @@ void ComMethodTable::LayOutClassMethodTable() pCurrInteropMD = &pCurrInteropMT->pVTable[i]; pMD = pCurrInteropMD->pMD; - if (pMD && - !(pCurrInteropMD ? IsDuplicateClassItfMD(pCurrInteropMD, i) : IsDuplicateClassItfMD(pMD, i)) && - IsNewComVisibleMember(pMD)) + if (pMD + && !(pCurrInteropMD ? IsDuplicateClassItfMD(pCurrInteropMD, i) : IsDuplicateClassItfMD(pMD, i)) + && IsNewComVisibleMember(pMD) + && !pMD->IsAsyncMethod()) { // some bytes are reserved for CALL xxx before the method desc ComCallMethodDesc* pNewMD = (ComCallMethodDesc *) (pMethodDescMemory + COMMETHOD_PREPAD); @@ -3376,8 +3378,12 @@ void ComMethodTable::LayOutClassMethodTable() if (!it.IsVirtual()) { MethodDesc* pMD = it.GetMethodDesc(); - if (pMD != NULL && !IsDuplicateClassItfMD(pMD, it.GetSlotNumber()) && - IsNewComVisibleMember(pMD) && !pMD->IsStatic() && !pMD->IsCtor() + if (pMD != NULL + && !IsDuplicateClassItfMD(pMD, it.GetSlotNumber()) + && IsNewComVisibleMember(pMD) + && !pMD->IsStatic() + && !pMD->IsCtor() + && !pMD->IsAsyncMethod() && (!pCurrMT->IsValueType() || (GetClassInterfaceType() != clsIfAutoDual && IsStrictlyUnboxed(pMD)))) { // some bytes are reserved for CALL xxx before the method desc @@ -3558,6 +3564,8 @@ BOOL ComMethodTable::LayOutInterfaceMethodTable(MethodTable* pClsMT) ArrayList NewCOMMethodDescs; ComCallMethodDescArrayHolder NewCOMMethodDescsHolder(&NewCOMMethodDescs); + unsigned numVtableSlots = 0; + for (i = 0; i < cbSlots; i++) { // Some space for a CALL xx xx xx xx stub is reserved before the beginning of the MethodDesc @@ -3566,6 +3574,15 @@ BOOL ComMethodTable::LayOutInterfaceMethodTable(MethodTable* pClsMT) MethodDesc* pIntfMD = m_pMT->GetMethodDescForSlot(i); + if (pIntfMD->IsAsyncMethod()) + { + // Async methods are not supported on COM interfaces + // And we don't include them in the calculation of COM vtable slots. + continue; + } + + numVtableSlots++; + if (m_pMT->HasInstantiation()) { pIntfMD = MethodDesc::FindOrCreateAssociatedMethodDesc( @@ -3616,8 +3633,10 @@ BOOL ComMethodTable::LayOutInterfaceMethodTable(MethodTable* pClsMT) SLOT *pComVtableRW = (SLOT*)((BYTE*)pComVtable + writeableOffset); // Method descs are at the end of the vtable - // m_cbSlots interfaces methods + IUnk methods + // numVtableSlots interfaces methods + IUnk methods + unsigned cbEmittedSlots = 0; pMethodDescMemory = (BYTE *)&pComVtable[m_cbSlots]; + _ASSERTE(numVtableSlots <= m_cbSlots); for (i = 0; i < cbSlots; i++) { ComCallMethodDesc* pNewMD = (ComCallMethodDesc *) (pMethodDescMemory + COMMETHOD_PREPAD); @@ -3625,14 +3644,25 @@ BOOL ComMethodTable::LayOutInterfaceMethodTable(MethodTable* pClsMT) MethodDesc* pIntfMD = m_pMT->GetMethodDescForSlot(i); + if (pIntfMD->IsAsyncMethod()) + { + // Async methods are not supported on COM interfaces + // We skip them above in the vtable calculation + // so don't fill in the COM vtable slot here. + continue; + } + emitCOMStubCall(pNewMD, pNewMDRW, GetEEFuncEntryPoint(ComCallPreStub)); UINT slotIndex = (pIntfMD->GetComSlot() - cbExtraSlots); FillInComVtableSlot(pComVtableRW, slotIndex, pNewMD); pMethodDescMemory += (COMMETHOD_PREPAD + sizeof(ComCallMethodDesc)); + cbEmittedSlots++; } + _ASSERTE(numVtableSlots == cbEmittedSlots); + // Set the layout complete flag and release the lock. comMTWriterHolder.GetRW()->m_Flags |= enum_LayoutComplete; NewCOMMethodDescsHolder.SuppressRelease(); @@ -4248,9 +4278,10 @@ ComMethodTable* ComCallWrapperTemplate::CreateComMethodTableForClass(MethodTable pCurrParentInteropMD = &pCurrParentInteropMT->pVTable[i]; pParentMD = pCurrParentInteropMD->pMD; - if (pMD && - !(pCurrInteropMD ? IsDuplicateClassItfMD(pCurrInteropMD, i) : IsDuplicateClassItfMD(pMD, i)) && - IsOverloadedComVisibleMember(pMD, pParentMD)) + if (pMD + && !(pCurrInteropMD ? IsDuplicateClassItfMD(pCurrInteropMD, i) : IsDuplicateClassItfMD(pMD, i)) + && IsOverloadedComVisibleMember(pMD, pParentMD) + && !pMD->IsAsyncMethod()) { cbNewPublicMethods++; } @@ -4267,9 +4298,10 @@ ComMethodTable* ComCallWrapperTemplate::CreateComMethodTableForClass(MethodTable pCurrInteropMD = &pCurrInteropMT->pVTable[i]; pMD = pCurrInteropMD->pMD; - if (pMD && - !(pCurrInteropMD ? IsDuplicateClassItfMD(pCurrInteropMD, i) : IsDuplicateClassItfMD(pMD, i)) && - IsNewComVisibleMember(pMD)) + if (pMD + && !(pCurrInteropMD ? IsDuplicateClassItfMD(pCurrInteropMD, i) : IsDuplicateClassItfMD(pMD, i)) + && IsNewComVisibleMember(pMD) + && !pMD->IsAsyncMethod()) { cbNewPublicMethods++; } @@ -4282,9 +4314,13 @@ ComMethodTable* ComCallWrapperTemplate::CreateComMethodTableForClass(MethodTable if (!it.IsVirtual()) { MethodDesc* pMD = it.GetMethodDesc(); - if (pMD && !IsDuplicateClassItfMD(pMD, it.GetSlotNumber()) && IsNewComVisibleMember(pMD) && - !pMD->IsStatic() && !pMD->IsCtor() && - (!pCurrMT->IsValueType() || (ClassItfType != clsIfAutoDual && IsStrictlyUnboxed(pMD)))) + if (pMD + && !IsDuplicateClassItfMD(pMD, it.GetSlotNumber()) + && IsNewComVisibleMember(pMD) + && !pMD->IsStatic() + && !pMD->IsCtor() + && !pMD->IsAsyncMethod() + && (!pCurrMT->IsValueType() || (ClassItfType != clsIfAutoDual && IsStrictlyUnboxed(pMD)))) { cbNewPublicMethods++; } diff --git a/src/coreclr/vm/commtmemberinfomap.cpp b/src/coreclr/vm/commtmemberinfomap.cpp index 95843168768eea..c83165f02006c1 100644 --- a/src/coreclr/vm/commtmemberinfomap.cpp +++ b/src/coreclr/vm/commtmemberinfomap.cpp @@ -361,6 +361,7 @@ void ComMTMemberInfoMap::SetupPropsForIClassX(size_t sizeOfPtr) // Retrieve the method desc on the current class. This involves looking up the method // desc in the vtable if it is a virtual method. pMeth = pCMT->GetMethodDescForSlot(i); + _ASSERTE(!pMeth->IsAsyncMethod()); if (pMeth->IsVirtual()) { WORD wSlot = InteropMethodTableData::GetSlotForMethodDesc(m_pMT, pMeth); @@ -541,6 +542,15 @@ void ComMTMemberInfoMap::SetupPropsForInterface(size_t sizeOfPtr) { MethodDesc* pMD = m_pMT->GetMethodDescForSlot(iMD); _ASSERTE(pMD != NULL); + + if (pMD->IsAsyncMethod()) + { + // Async methods introduce mismatches in the .NET and COM vtables. + // We will need to remap slots. + bSlotRemap = true; + continue; + } + ULONG tmp = pMD->GetComSlot(); if (tmp < ulComSlotMin) @@ -552,10 +562,13 @@ void ComMTMemberInfoMap::SetupPropsForInterface(size_t sizeOfPtr) // Used a couple of times. MethodTable::MethodIterator it(m_pMT); - if (ulComSlotMax-ulComSlotMin >= nSlots) + if (ulComSlotMax - ulComSlotMin >= nSlots) { bSlotRemap = true; + } + if (bSlotRemap) + { // Resize the array. rSlotMap.ReSizeThrows(ulComSlotMax+1); @@ -566,7 +579,7 @@ void ComMTMemberInfoMap::SetupPropsForInterface(size_t sizeOfPtr) it.MoveToBegin(); for (; it.IsValid(); it.Next()) { - if (it.IsVirtual()) + if (it.IsVirtual() && !it.GetMethodDesc()->IsAsyncMethod()) { MethodDesc* pMD = it.GetMethodDesc(); _ASSERTE(pMD != NULL); @@ -590,7 +603,7 @@ void ComMTMemberInfoMap::SetupPropsForInterface(size_t sizeOfPtr) if (it.IsVirtual()) { pMeth = it.GetMethodDesc(); - if (pMeth != NULL) + if (pMeth != NULL && !pMeth->IsAsyncMethod()) { ULONG ixSlot = pMeth->GetComSlot(); if (bSlotRemap) @@ -607,6 +620,22 @@ void ComMTMemberInfoMap::SetupPropsForInterface(size_t sizeOfPtr) for (iMD=0; iMD < nSlots; ++iMD) { pMeth = m_MethodProps[iMD].pMeth; + if (pMeth == nullptr) + { + // For async methods, we skip .NET methods when building the COM vtable. + // So at some point we hit the end of the .NET methods before we run + // through all possible vtable slots. + // Record when we ran out here. +#ifdef _DEBUG + // In debug, validate that all remaining slots are null. + for (unsigned j = iMD; j < nSlots; ++j) + { + _ASSERTE(m_MethodProps[j].pMeth == nullptr); + } +#endif + nSlots = iMD; + break; + } GetMethodPropsForMeth(pMeth, iMD, m_MethodProps, m_sNames); } @@ -688,9 +717,7 @@ void ComMTMemberInfoMap::GetMethodPropsForMeth( // Generally don't munge function into a getter. rProps[ix].bFunction2Getter = FALSE; - // TODO: (async) revisit and examine if this needs to be supported somehow - if (pMeth->IsAsyncMethod()) - ThrowHR(COR_E_NOTSUPPORTED); + _ASSERTE(!pMeth->IsAsyncMethod()); // See if there is property information for this member. hr = pMeth->GetMDImport()->GetPropertyInfoForMethodDef(pMeth->GetMemberDef(), &pd, &pPropName, &uSemantic); @@ -1608,11 +1635,7 @@ void ComMTMemberInfoMap::PopulateMemberHashtable() // We are dealing with a method. MethodDesc *pMD = pProps->pMeth; - // TODO: (async) revisit and examine if this needs to be supported somehow - if (pMD->IsAsyncMethod()) - { - ThrowHR(COR_E_NOTSUPPORTED); // Probably this isn't right, and instead should be a skip, but a throw makes it easier to find if this is wrong - } + _ASSERTE(!pMD->IsAsyncMethod()); EEModuleTokenPair Key(pMD->GetMemberDef(), pMD->GetModule()); m_TokenToComMTMethodPropsMap.InsertValue(&Key, (HashDatum)pProps); } diff --git a/src/coreclr/vm/comtoclrcall.cpp b/src/coreclr/vm/comtoclrcall.cpp index cc8a1129618712..2f22baa3c14e96 100644 --- a/src/coreclr/vm/comtoclrcall.cpp +++ b/src/coreclr/vm/comtoclrcall.cpp @@ -263,6 +263,8 @@ OBJECTREF COMToCLRGetObjectAndTarget_NonVirtual(ComCallWrapper * pWrap, MethodDe } CONTRACTL_END; + CONTRACT_VIOLATION(ThrowsViolation); + //NOTE: No need to optimize for stub dispatch since non-virtuals are retrieved quickly. *ppManagedTargetOut = pRealMD->GetSingleCallableAddrOfCode(); @@ -850,6 +852,7 @@ void ComCallMethodDesc::InitMethod(MethodDesc *pMD, MethodDesc *pInterfaceMD) GC_TRIGGERS; MODE_ANY; PRECONDITION(CheckPointer(pMD)); + PRECONDITION(!pMD->IsAsyncMethod()); } CONTRACTL_END; @@ -974,6 +977,7 @@ void ComCallMethodDesc::InitNativeInfo() else { MethodDesc *pMD = GetCallMethodDesc(); + _ASSERTE(!pMD->IsAsyncMethod()); // Async methods should never have a ComCallMethodDesc. #ifdef _DEBUG LPCUTF8 szDebugName = pMD->m_pszDebugMethodName; @@ -985,9 +989,6 @@ void ComCallMethodDesc::InitNativeInfo() MethodTable * pMT = pMD->GetMethodTable(); IMDInternalImport * pInternalImport = pMT->GetMDImport(); - // TODO: (async) revisit and examine if this needs to be supported somehow - if (pMD->IsAsyncMethod()) - ThrowHR(COR_E_NOTSUPPORTED); mdMethodDef md = pMD->GetMemberDef(); diff --git a/src/coreclr/vm/dispatchinfo.cpp b/src/coreclr/vm/dispatchinfo.cpp index 57043fc89d9922..6c03246ed49528 100644 --- a/src/coreclr/vm/dispatchinfo.cpp +++ b/src/coreclr/vm/dispatchinfo.cpp @@ -450,9 +450,9 @@ ComMTMethodProps * DispatchMemberInfo::GetMemberProps(OBJECTREF MemberInfoObj, C MethodDesc* pMeth = (MethodDesc*) getMethodHandle.Call_RetLPVOID(&GetMethodHandleArg); if (pMeth) { - // TODO: (async) revisit and examine if this needs to be supported somehow + // We don't expose runtime-async methods via IDispatch. if (pMeth->IsAsyncMethod()) - ThrowHR(COR_E_NOTSUPPORTED); + RETURN NULL; pMemberProps = pMemberMap->GetMethodProps(pMeth->GetMemberDef(), pMeth->GetModule()); } @@ -830,15 +830,12 @@ void DispatchMemberInfo::SetUpMethodMarshalerInfo(MethodDesc *pMD, BOOL bReturnV GC_TRIGGERS; MODE_ANY; PRECONDITION(CheckPointer(pMD)); + PRECONDITION(!pMD->IsAsyncMethod()); } CONTRACTL_END; GCX_PREEMP(); - // TODO: (async) revisit and examine if this needs to be supported somehow - if (pMD->IsAsyncMethod()) - ThrowHR(COR_E_NOTSUPPORTED); - MetaSig msig(pMD); LPCSTR szName; USHORT usSequence; @@ -2584,7 +2581,7 @@ bool DispatchInfo::IsPropertyAccessorVisible(bool fIsSetter, OBJECTREF* pMemberI // Check to see if the new method is a property accessor. mdToken tkMember = mdTokenNil; - // TODO: (async) revisit and examine if this needs to be supported somehow + // Runtime-async property accessors are not visible from COM if (pMDForProperty->IsAsyncVariantMethod()) { return false; diff --git a/src/coreclr/vm/interoputil.cpp b/src/coreclr/vm/interoputil.cpp index 049e8854c13afa..9738b463de1629 100644 --- a/src/coreclr/vm/interoputil.cpp +++ b/src/coreclr/vm/interoputil.cpp @@ -2258,9 +2258,9 @@ ULONG GetStringizedClassItfDef(TypeHandle InterfaceType, CQuickArray &rDef { pDeclaringMT = pProps->pMeth->GetMethodTable(); tkMb = pProps->pMeth->GetMemberDef(); - // TODO: (async) revisit and examine if this needs to be supported somehow - if (pProps->pMeth->IsAsyncMethod()) - ThrowHR(COR_E_NOTSUPPORTED); + + // ComMTMemberInfoMap should not contain any async methods. + _ASSERTE(!pProps->pMeth->IsAsyncMethod()); cbCur = GetStringizedMethodDef(pDeclaringMT, tkMb, rDef, cbCur); } @@ -2472,7 +2472,7 @@ BOOL IsMethodVisibleFromCom(MethodDesc *pMD) mdProperty pd; LPCUTF8 pPropName; ULONG uSemantic; - // TODO: (async) revisit and examine if this needs to be supported somehow + // Async methods are not visible from COM. if (pMD->IsAsyncMethod()) return false; diff --git a/src/coreclr/vm/method.cpp b/src/coreclr/vm/method.cpp index dda289de1ed9f2..3858ce0eb438e9 100644 --- a/src/coreclr/vm/method.cpp +++ b/src/coreclr/vm/method.cpp @@ -1357,6 +1357,7 @@ WORD MethodDesc::GetComSlot() THROWS; GC_NOTRIGGER; FORBID_FAULT; + PRECONDITION(!IsAsyncMethod()); } CONTRACTL_END diff --git a/src/coreclr/vm/methodtablebuilder.cpp b/src/coreclr/vm/methodtablebuilder.cpp index 99d20f6afa4e57..a55ed3964c0f32 100644 --- a/src/coreclr/vm/methodtablebuilder.cpp +++ b/src/coreclr/vm/methodtablebuilder.cpp @@ -3431,6 +3431,17 @@ MethodTableBuilder::EnumerateClassMethods() pNewMemberSignature[offsetOfAsyncDetails] = ELEMENT_TYPE_CMOD_REQD; } + MethodClassification asyncVariantType = type; +#ifdef FEATURE_COMINTEROP + if (type == mcComInterop) + { + // For COM interop methods, + // we don't want to treat the async variant as a COM Interop method + // (as it isn't, it's a transient IL method). + asyncVariantType = mcIL; + } +#endif // FEATURE_COMINTEROP + Signature newMemberSig(pNewMemberSignature, cAsyncThunkMemberSignature); pNewMethod = new (GetStackingAllocator()) bmtMDMethod( bmtInternal->pType, @@ -3440,11 +3451,23 @@ MethodTableBuilder::EnumerateClassMethods() dwMethodRVA, newMemberSig, asyncFlags, - type, + asyncVariantType, implType); pNewMethod->SetAsyncOtherVariant(pDeclaredMethod); pDeclaredMethod->SetAsyncOtherVariant(pNewMethod); + +#ifdef FEATURE_COMINTEROP + // We only ever include one of the two async variants (whichever doesn't have the async calling convention) + // Record an excluded method here in the COM VTable. + EnsureOptionalFieldsAreAllocated(GetHalfBakedClass(), m_pAllocMemTracker, GetLoaderAllocator()->GetLowFrequencyHeap()); + if (GetHalfBakedClass()->GetSparseCOMInteropVTableMap() == NULL) + GetHalfBakedClass()->SetSparseCOMInteropVTableMap(new SparseVTableMap()); + + GetHalfBakedClass()->GetSparseCOMInteropVTableMap()->RecordExcludedMethod((WORD)NumDeclaredMethods()); + + bmtProp->fSparse = true; +#endif // FEATURE_COMINTEROP } bmtMethod->AddDeclaredMethod(pNewMethod); diff --git a/src/tests/Interop/CMakeLists.txt b/src/tests/Interop/CMakeLists.txt index f83f5c1c6bd568..b28f6629188dbb 100644 --- a/src/tests/Interop/CMakeLists.txt +++ b/src/tests/Interop/CMakeLists.txt @@ -86,6 +86,7 @@ if(CLR_CMAKE_TARGET_WIN32) add_subdirectory(COM/NativeClients/Dispatch) add_subdirectory(COM/NativeClients/Events) add_subdirectory(COM/NativeClients/MiscTypes) + add_subdirectory(COM/RuntimeAsync) # IJW isn't supported on ARM64 if(NOT CLR_CMAKE_HOST_ARCH_ARM64) diff --git a/src/tests/Interop/COM/RuntimeAsync/CMakeLists.txt b/src/tests/Interop/COM/RuntimeAsync/CMakeLists.txt new file mode 100644 index 00000000000000..74652650691256 --- /dev/null +++ b/src/tests/Interop/COM/RuntimeAsync/CMakeLists.txt @@ -0,0 +1,10 @@ +include_directories( ${INC_PLATFORM_DIR} ) +include_directories("../ServerContracts" ) +include_directories("../NativeServer" ) +include_directories("../NativeClients") +set(SOURCES + RuntimeAsyncNative.cpp) + +# add the executable +add_library (RuntimeAsyncNative SHARED ${SOURCES}) +target_link_libraries(RuntimeAsyncNative PRIVATE ${LINK_LIBRARIES_ADDITIONAL}) diff --git a/src/tests/Interop/COM/RuntimeAsync/CompilerAsync.csproj b/src/tests/Interop/COM/RuntimeAsync/CompilerAsync.csproj new file mode 100644 index 00000000000000..7b1576fda6c546 --- /dev/null +++ b/src/tests/Interop/COM/RuntimeAsync/CompilerAsync.csproj @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/src/tests/Interop/COM/RuntimeAsync/RuntimeAsync.cs b/src/tests/Interop/COM/RuntimeAsync/RuntimeAsync.cs new file mode 100644 index 00000000000000..529a4dec715ca3 --- /dev/null +++ b/src/tests/Interop/COM/RuntimeAsync/RuntimeAsync.cs @@ -0,0 +1,116 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.InteropServices; +using System.Threading.Tasks; +using Xunit; + +[ConditionalClass(typeof(TestLibrary.PlatformDetection), nameof(TestLibrary.PlatformDetection.IsBuiltInComEnabled))] +public class RuntimeAsyncBuiltInCom +{ + public const int ExpectedIntValue = 42; + public const float ExpectedClassFloatValue = 3.14f; + public const float ExpectedInterfaceFloatValue = 2.71f; + + [Fact] + public static void RuntimeAsyncThunksDoNotModifyCcwVtable() + { + ExposedToCom obj = new(); + Assert.True(RuntimeAsyncNative.ValidateSlotLayoutForDefaultInterface(obj, ExpectedIntValue, ExpectedClassFloatValue)); + Assert.True(RuntimeAsyncNative.ValidateSlotLayoutForInterface(obj, ExpectedInterfaceFloatValue)); + } + + [Fact] + public static void RuntimeAsyncDoNotModifyRcwVtable() + { + using (ComActivationHelpers.RegisterTypeForActivation()) + { + var myObjectType = Type.GetTypeFromCLSID(typeof(TaskComServer).GUID, throwOnError: true)!; + object obj = Activator.CreateInstance(myObjectType)!; + ITaskComServer_Imported comObject = (ITaskComServer_Imported)obj; + TestAsyncMethod(comObject).GetAwaiter().GetResult(); + + Assert.Equal(TaskComServer.ExpectedValue, comObject.GetValue()); + + static async Task TestAsyncMethod(ITaskComServer_Imported obj) + { + await obj.GetTask(); + } + } + } + + [Fact] + public static void IDispatchCallInvokesCorrectMethod() + { + using (ComActivationHelpers.RegisterTypeForActivation()) + { + var myObjectType = Type.GetTypeFromCLSID(typeof(TaskComServer).GUID, throwOnError: true)!; + object obj = Activator.CreateInstance(myObjectType)!; + ITaskComServer_AsDispatchOnly comObject = (ITaskComServer_AsDispatchOnly)obj; + TestAsyncMethod(comObject).GetAwaiter().GetResult(); + + Assert.Equal(TaskComServer.ExpectedValue, comObject.GetValue()); + + static async Task TestAsyncMethod(ITaskComServer_AsDispatchOnly obj) + { + await obj.GetTask(); + } + } + } +} + +public static class RuntimeAsyncNative +{ + [DllImport("RuntimeAsyncNative")] + [return: MarshalAs(UnmanagedType.U1)] + public static extern bool ValidateSlotLayoutForDefaultInterface([MarshalAs(UnmanagedType.Interface)] object comObject, int expectedIntValue, float expectedFloatValue); + + [DllImport("RuntimeAsyncNative")] + [return: MarshalAs(UnmanagedType.U1)] + public static extern bool ValidateSlotLayoutForInterface([MarshalAs(UnmanagedType.Interface)] object comObject, float expectedFloatValue); +} + +[ComVisible(true)] +[InterfaceType(ComInterfaceType.InterfaceIsIUnknown)] +public interface IExposedToComInterface +{ + Task AsyncMethodOnInterface(); + + float FloatMethodOnInterface(); +} + +[ComVisible(true)] +[ClassInterface(ClassInterfaceType.AutoDual)] +public class ExposedToCom : IExposedToComInterface +{ + public int MyMethod() + { + return RuntimeAsyncBuiltInCom.ExpectedIntValue; + } + + public async Task MyAsyncMethod() + { + return await Task.FromResult(1); + } + + public async Task MyAsyncMethod2() + { + await Task.Run(() => { }); + } + + public float MyFloatMethod() + { + return RuntimeAsyncBuiltInCom.ExpectedClassFloatValue; + } + + async Task IExposedToComInterface.AsyncMethodOnInterface() + { + await Task.Run(() => { }); + } + + float IExposedToComInterface.FloatMethodOnInterface() + { + return RuntimeAsyncBuiltInCom.ExpectedInterfaceFloatValue; + } +} diff --git a/src/tests/Interop/COM/RuntimeAsync/RuntimeAsync.csproj b/src/tests/Interop/COM/RuntimeAsync/RuntimeAsync.csproj new file mode 100644 index 00000000000000..0744068de54384 --- /dev/null +++ b/src/tests/Interop/COM/RuntimeAsync/RuntimeAsync.csproj @@ -0,0 +1,20 @@ + + + + true + $(NoWarn);SYSLIB5007 + $(Features);runtime-async=on + + + + + + + + + + + + + + diff --git a/src/tests/Interop/COM/RuntimeAsync/RuntimeAsyncNative.cpp b/src/tests/Interop/COM/RuntimeAsync/RuntimeAsyncNative.cpp new file mode 100644 index 00000000000000..d4e2b4a41cc942 --- /dev/null +++ b/src/tests/Interop/COM/RuntimeAsync/RuntimeAsyncNative.cpp @@ -0,0 +1,81 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#include +#include +#include +#include +#include + +// COM headers +#include +#include + +extern "C" DLL_EXPORT bool STDMETHODCALLTYPE ValidateSlotLayoutForDefaultInterface(IUnknown* pUnk, int expectedIntValue, float expectedFloatValue) +{ + ComSmartPtr spUnk(pUnk); + + ComSmartPtr spDefaultInterface; + HRESULT hr = spUnk->QueryInterface(&spDefaultInterface); + if (FAILED(hr)) + { + printf("QueryInterface for IClassDefaultInterfaceExposedToCom failed with hr=0x%08X\n", hr); + return false; + } + + int intValue = 0; + float floatValue = 0.0f; + if (FAILED(spDefaultInterface->MyMethod(&intValue))) + { + printf("MyMethod failed\n"); + return false; + } + + if (intValue != expectedIntValue) + { + printf("MyMethod returned intValue=%d, expected %d\n", intValue, expectedIntValue); + return false; + } + + if (FAILED(spDefaultInterface->MyFloatMethod(&floatValue))) + { + printf("MyFloatMethod failed\n"); + return false; + } + + if (floatValue != expectedFloatValue) + { + printf("MyFloatMethod returned floatValue=%f, expected %f\n", floatValue, expectedFloatValue); + return false; + } + + return true; +} + +extern "C" DLL_EXPORT bool STDMETHODCALLTYPE ValidateSlotLayoutForInterface(IUnknown* pUnk, float expectedFloatValue) +{ + ComSmartPtr spUnk(pUnk); + + ComSmartPtr spInterface; + HRESULT hr = spUnk->QueryInterface(&spInterface); + if (FAILED(hr)) + { + printf("QueryInterface for IInterfaceExposedToCom failed with hr=0x%08X\n", hr); + return false; + } + + float floatValue = 0.0f; + if (FAILED(spInterface->FloatMethodOnInterface(&floatValue))) + { + printf("FloatMethodOnInterface failed\n"); + return false; + } + + if (floatValue != expectedFloatValue) + { + printf("FloatMethodOnInterface returned floatValue=%f, expected %f\n", floatValue, expectedFloatValue); + return false; + } + + return true; +} diff --git a/src/tests/Interop/COM/RuntimeAsync/TaskComServer.cs b/src/tests/Interop/COM/RuntimeAsync/TaskComServer.cs new file mode 100644 index 00000000000000..4ef3cef2c8b0a7 --- /dev/null +++ b/src/tests/Interop/COM/RuntimeAsync/TaskComServer.cs @@ -0,0 +1,54 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// + +using System; +using System.Runtime.InteropServices; +using System.Threading.Tasks; + +[ComVisible(true)] +[Guid("923dd7d3-a3c0-4fe2-9222-170f9d983724")] +public interface ITaskComServer +{ + Task GetTask(); + int GetValue(); +} + +[ComVisible(true)] +[Guid("9ad32e71-347f-47e7-8fe2-096975e65b00")] +public class TaskComServer : ITaskComServer +{ + public const int ExpectedValue = 12345; + + public Task GetTask() + { + // Create a definitely-not-generic Task. + // Non-blittable generics can't be marshalled, + // so we can't have Task. + Task t = new(() => {}); + t.Start(); + return t; + } + + public int GetValue() + { + return ExpectedValue; + } +} + +[ComImport] +[Guid("923dd7d3-a3c0-4fe2-9222-170f9d983724")] +public interface ITaskComServer_Imported +{ + Task GetTask(); + int GetValue(); +} + +[ComImport] +[Guid("923dd7d3-a3c0-4fe2-9222-170f9d983724")] +[InterfaceType(ComInterfaceType.InterfaceIsIDispatch)] +public interface ITaskComServer_AsDispatchOnly +{ + Task GetTask(); + int GetValue(); +} diff --git a/src/tests/Interop/COM/ServerContracts/Server.Contracts.h b/src/tests/Interop/COM/ServerContracts/Server.Contracts.h index b73c0c54b8b2f9..76301fe0ae2b40 100644 --- a/src/tests/Interop/COM/ServerContracts/Server.Contracts.h +++ b/src/tests/Interop/COM/ServerContracts/Server.Contracts.h @@ -546,4 +546,30 @@ ITrackMyLifetimeTesting : IUnknown virtual HRESULT STDMETHODCALLTYPE GetAllocationCountCallback(_Outptr_ void** fptr) = 0; }; +// IIDs for the below types are generated by the runtime. +// They are not randomly chosen. +struct __declspec(uuid("AA0540BD-56C8-399B-B653-B787A33827F3")) +IClassDefaultInterfaceExposedToCom : IDispatch +{ + // _Object members from IClassX + virtual HRESULT STDMETHODCALLTYPE ToString(_Out_ _Ret_ BSTR* pRetVal) = 0; + virtual HRESULT STDMETHODCALLTYPE Equals(_In_ IUnknown* other, _Out_ _Ret_ VARIANT_BOOL* pRetVal) = 0; + virtual HRESULT STDMETHODCALLTYPE GetHashCode(_Out_ _Ret_ int* pRetVal) = 0; + virtual HRESULT STDMETHODCALLTYPE GetType(_Out_ _Ret_ IUnknown* pRetVal) = 0; + + // ExposedToCom members + virtual HRESULT STDMETHODCALLTYPE MyMethod(_Out_ _Ret_ int* pRetVal) = 0; + virtual HRESULT STDMETHODCALLTYPE MyAsyncMethod(_Out_ _Ret_ IUnknown** pRetVal) = 0; + virtual HRESULT STDMETHODCALLTYPE MyAsyncMethod2(_Out_ _Ret_ IUnknown** pRetVal) = 0; + virtual HRESULT STDMETHODCALLTYPE MyFloatMethod(_Out_ _Ret_ float* pRetVal) = 0; +}; + + +struct __declspec(uuid("1BE53E57-C3E4-3A1D-A751-F12207A0F8A8")) +IInterfaceExposedToCom : IUnknown +{ + virtual HRESULT STDMETHODCALLTYPE AsyncMethodOnInterface(_Out_ _Ret_ IUnknown** pRetVal) = 0; + virtual HRESULT STDMETHODCALLTYPE FloatMethodOnInterface(_Out_ _Ret_ float* pRetVal) = 0; +}; + #pragma pack(pop)