Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Test the Marshal.QueryInterface, Marshal.AddRef, and Marshal.Release
using ComWrappers.
  • Loading branch information
AaronRobinsonMSFT committed Jul 30, 2021
commit df278421323b08e2d6631a02fd87a9ddf37aa224
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@
<Compile Include="System\Runtime\InteropServices\Marshal\ZeroFreeGlobalAllocAnsiTests.cs" />
<Compile Include="System\Runtime\InteropServices\Marshal\ZeroFreeGlobalAllocUnicodeTests.cs" />
<Compile Include="System\Runtime\InteropServices\Marshal\Common\CommonTypes.cs" />
<Compile Include="System\Runtime\InteropServices\Marshal\Common\COMWrappersImpl.cs" />
<Compile Include="System\Runtime\InteropServices\Marshal\Common\CommonTypes.Windows.cs" Condition="'$(TargetsWindows)' == 'true'" />
<Compile Include="System\Runtime\InteropServices\Marshal\Common\Variant.cs" />
<Compile Include="System\Runtime\InteropServices\Marshal\ReadWrite\ByteTests.cs" />
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Runtime.InteropServices.Tests.Common;
using Xunit;

namespace System.Runtime.InteropServices.Tests
{
public class AddRefTests
{
[Fact]
[PlatformSpecific(TestPlatforms.Windows)]
public void AddRef_ValidPointer_Success()
{
IntPtr iUnknown = Marshal.GetIUnknownForObject(new object());
var cw = new ComWrappersImpl();
IntPtr iUnknown = cw.GetOrCreateComInterfaceForObject(new object(), CreateComInterfaceFlags.None);
try
{
Assert.Equal(2, Marshal.AddRef(iUnknown));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections;
using System.Runtime.CompilerServices;
using Xunit;

namespace System.Runtime.InteropServices.Tests.Common
{
public class ComWrappersImpl : ComWrappers
{
public const string IID_TestQueryInterface = "1F906666-B388-4729-B78C-826BC5FD4245";

protected unsafe override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count)
{
Assert.Equal(CreateComInterfaceFlags.None, flags);

IntPtr fpQueryInterface = default;
IntPtr fpAddRef = default;
IntPtr fpRelease = default;
ComWrappers.GetIUnknownImpl(out fpQueryInterface, out fpAddRef, out fpRelease);

var vtblRaw = (IntPtr*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ComWrappersImpl), IntPtr.Size * 3);
vtblRaw[0] = fpQueryInterface;
vtblRaw[1] = fpAddRef;
vtblRaw[2] = fpRelease;

var entryRaw = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ComWrappersImpl), sizeof(ComInterfaceEntry));
entryRaw->IID = new Guid(IID_TestQueryInterface);
entryRaw->Vtable = (IntPtr)vtblRaw;

count = 1;
return entryRaw;
}

protected override object CreateObject(IntPtr externalComObject, CreateObjectFlags flag)
=> throw new NotImplementedException();

protected override void ReleaseObjects(IEnumerable objects)
=> throw new NotImplementedException();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ namespace System.Runtime.InteropServices.Tests
{
public partial class QueryInterfaceTests
{
public const string IID_IDISPATCH = "00020400-0000-0000-C000-000000000046";
public const string IID_IINSPECTABLE = "AF86E2E0-B12D-4c6a-9C5A-D7AA65101E90";

public static IEnumerable<object[]> QueryInterface_ValidComObjectInterface_TestData()
Expand Down Expand Up @@ -45,7 +46,25 @@ public static IEnumerable<object[]> QueryInterface_ValidComObjectInterface_TestD
[MemberData(nameof(QueryInterface_ValidComObjectInterface_TestData))]
public void QueryInterface_ValidComObjectInterface_Success(object o, string iidString)
{
QueryInterface_ValidInterface_Success(o, iidString);
IntPtr ptr = Marshal.GetIUnknownForObject(o);
try
{
Guid guid = new Guid(iidString);
Assert.Equal(0, Marshal.QueryInterface(ptr, ref guid, out IntPtr ppv));
Assert.NotEqual(IntPtr.Zero, ppv);
try
{
Assert.Equal(new Guid(iidString), guid);
}
finally
{
Marshal.Release(ppv);
}
}
finally
{
Marshal.Release(ptr);
}
}

public static IEnumerable<object[]> QueryInterface_NoSuchComObjectInterface_TestData()
Expand Down Expand Up @@ -83,7 +102,18 @@ public static IEnumerable<object[]> QueryInterface_NoSuchComObjectInterface_Test
[MemberData(nameof(QueryInterface_NoSuchComObjectInterface_TestData))]
public void QueryInterface_NoSuchComObjectInterface_Success(object o, string iidString)
{
QueryInterface_NoSuchInterface_Success(o, iidString);
IntPtr ptr = Marshal.GetIUnknownForObject(o);
try
{
Guid iid = new Guid(iidString);
Assert.Equal(E_NOINTERFACE, Marshal.QueryInterface(ptr, ref iid, out IntPtr ppv));
Assert.Equal(IntPtr.Zero, ppv);
Assert.Equal(new Guid(iidString), iid);
}
finally
{
Marshal.Release(ptr);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,71 +12,27 @@ public partial class QueryInterfaceTests
{
public const int E_NOINTERFACE = unchecked((int)0x80004002);
public const string IID_IUNKNOWN = "00000000-0000-0000-C000-000000000046";
public const string IID_IDISPATCH = "00020400-0000-0000-C000-000000000046";

public static IEnumerable<object[]> QueryInterface_ValidInterface_TestData()
{
yield return new object[] { new object(), IID_IUNKNOWN };
yield return new object[] { new object(), IID_IDISPATCH };

yield return new object[] { 10, IID_IUNKNOWN };
if (!PlatformDetection.IsNetCore)
{
yield return new object[] { 10, IID_IDISPATCH };
}

yield return new object[] { "string", IID_IUNKNOWN };
if (!PlatformDetection.IsNetCore)
{
yield return new object[] { "string", IID_IDISPATCH };
}

yield return new object[] { new NonGenericClass(), IID_IUNKNOWN };
if (!PlatformDetection.IsNetCore)
{
yield return new object[] { new NonGenericClass(), IID_IDISPATCH };
}
yield return new object[] { new GenericClass<string>(), IID_IUNKNOWN };

yield return new object[] { new NonGenericStruct(), IID_IUNKNOWN };
if (!PlatformDetection.IsNetCore)
{
yield return new object[] { new NonGenericStruct(), IID_IDISPATCH };
}
yield return new object[] { new GenericStruct<string>(), IID_IUNKNOWN };

yield return new object[] { Int32Enum.Value1, IID_IUNKNOWN };
if (!PlatformDetection.IsNetCore)
{
yield return new object[] { Int32Enum.Value1, IID_IDISPATCH };
}

yield return new object[] { new int[] { 10 }, IID_IUNKNOWN };
yield return new object[] { new int[][] { new int[] { 10 } }, IID_IUNKNOWN };
yield return new object[] { new int[,] { { 10 } }, IID_IUNKNOWN };

MethodInfo method = typeof(GetObjectForIUnknownTests).GetMethod(nameof(NonGenericMethod), BindingFlags.NonPublic | BindingFlags.Static);
Delegate d = method.CreateDelegate(typeof(NonGenericDelegate));
yield return new object[] { d, IID_IUNKNOWN };
yield return new object[] { d, IID_IDISPATCH };

yield return new object[] { new KeyValuePair<string, int>("key", 10), IID_IUNKNOWN };
yield return new object[] { new object(), ComWrappersImpl.IID_TestQueryInterface };
}

[Theory]
[MemberData(nameof(QueryInterface_ValidInterface_TestData))]
[PlatformSpecific(TestPlatforms.Windows)]
public void QueryInterface_ValidInterface_Success(object o, string guid)
public void QueryInterface_ValidInterface_Success(object o, string iidString)
{
IntPtr ptr = Marshal.GetIUnknownForObject(o);
var cw = new ComWrappersImpl();
IntPtr ptr = cw.GetOrCreateComInterfaceForObject(o, CreateComInterfaceFlags.None);
try
{
Guid iidString = new Guid(guid);
Assert.Equal(0, Marshal.QueryInterface(ptr, ref iidString, out IntPtr ppv));
Guid guid = new Guid(iidString);
Assert.Equal(0, Marshal.QueryInterface(ptr, ref guid, out IntPtr ppv));
Assert.NotEqual(IntPtr.Zero, ppv);
try
{
Assert.Equal(new Guid(guid), iidString);
Assert.Equal(new Guid(iidString), guid);
}
finally
{
Expand All @@ -93,23 +49,14 @@ public static IEnumerable<object[]> QueryInterface_NoSuchInterface_TestData()
{
yield return new object[] { new object(), Guid.Empty.ToString() };
yield return new object[] { new object(), "927971f5-0939-11d1-8be1-00c04fd8d503" };

yield return new object[] { new int[] { 10 }, IID_IDISPATCH };
yield return new object[] { new int[][] { new int[] { 10 } }, IID_IDISPATCH };
yield return new object[] { new int[,] { { 10 } }, IID_IDISPATCH };

yield return new object[] { new GenericClass<string>(), IID_IDISPATCH };
yield return new object[] { new Dictionary<string, int>(), IID_IDISPATCH };
yield return new object[] { new GenericStruct<string>(), IID_IDISPATCH };
yield return new object[] { new KeyValuePair<string, int>(), IID_IDISPATCH };
}

[Theory]
[MemberData(nameof(QueryInterface_NoSuchInterface_TestData))]
[PlatformSpecific(TestPlatforms.Windows)]
public void QueryInterface_NoSuchInterface_Success(object o, string iidString)
{
IntPtr ptr = Marshal.GetIUnknownForObject(o);
var cw = new ComWrappersImpl();
IntPtr ptr = cw.GetOrCreateComInterfaceForObject(o, CreateComInterfaceFlags.None);
try
{
Guid iid = new Guid(iidString);
Expand All @@ -129,10 +76,5 @@ public void QueryInterface_ZeroPointer_ThrowsArgumentNullException()
Guid iid = Guid.Empty;
AssertExtensions.Throws<ArgumentNullException>("pUnk", () => Marshal.QueryInterface(IntPtr.Zero, ref iid, out IntPtr ppv));
}

private static void NonGenericMethod(int i) { }
public delegate void NonGenericDelegate(int i);

public enum Int32Enum : int { Value1, Value2 }
}
}
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Runtime.InteropServices.Tests.Common;
using Xunit;

namespace System.Runtime.InteropServices.Tests
{
public class ReleaseTests
{
[Fact]
[PlatformSpecific(TestPlatforms.Windows)]
public void Release_ValidPointer_Success()
{
IntPtr iUnknown = Marshal.GetIUnknownForObject(new object());
var cw = new ComWrappersImpl();
IntPtr iUnknown = cw.GetOrCreateComInterfaceForObject(new object(), CreateComInterfaceFlags.None);
try
{
Marshal.AddRef(iUnknown);
Expand Down