Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add test for weak reference + aggregation with native weak reference …
…impl.
  • Loading branch information
jkoritzinsky committed Nov 5, 2021
commit b74ea9a8cc483adeefcaae55cb2db8b7b1d38a41
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,122 @@ namespace
return UnknownImpl::DoRelease();
}
};

struct WeakReferenceSource : public IWeakReferenceSource, public IInspectable
{
private:
IUnknown* _outerUnknown;
ComSmartPtr<WeakReference> _weakReference;
public:
WeakReferenceSource(IUnknown* outerUnknown)
:_outerUnknown(outerUnknown),
_weakReference(new WeakReference(this, 1))
{
}

STDMETHOD(GetWeakReference)(IWeakReference** ppWeakReference)
{
_weakReference->AddRef();
*ppWeakReference = _weakReference;
return S_OK;
}

STDMETHOD(QueryInterface)(
/* [in] */ REFIID riid,
/* [iid_is][out] */ void ** ppvObject)
{
if (riid == __uuidof(IWeakReferenceSource))
{
*ppvObject = static_cast<IWeakReferenceSource*>(this);
_weakReference->AddStrongRef();
return S_OK;
}
return _outerUnknown->QueryInterface(riid, ppvObject);
}
STDMETHOD_(ULONG, AddRef)(void)
{
return _weakReference->AddStrongRef();
}
STDMETHOD_(ULONG, Release)(void)
{
return _weakReference->ReleaseStrongRef();
}

STDMETHOD(GetRuntimeClassName)(HSTRING* pRuntimeClassName)
{
return E_NOTIMPL;
}

STDMETHOD(GetIids)(
ULONG *iidCount,
IID **iids)
{
return E_NOTIMPL;
}

STDMETHOD(GetTrustLevel)(TrustLevel *trustLevel)
{
*trustLevel = FullTrust;
return S_OK;
}
};

struct AggregatedWeakReferenceSource : IInspectable
{
private:
IUnknown* _outerUnknown;
ComSmartPtr<WeakReferenceSource> _weakReference;
public:
AggregatedWeakReferenceSource(IUnknown* outerUnknown)
:_outerUnknown(outerUnknown),
_weakReference(new WeakReferenceSource(outerUnknown))
{
}

STDMETHOD(GetRuntimeClassName)(HSTRING* pRuntimeClassName)
{
return E_NOTIMPL;
}

STDMETHOD(GetIids)(
ULONG *iidCount,
IID **iids)
{
return E_NOTIMPL;
}

STDMETHOD(GetTrustLevel)(TrustLevel *trustLevel)
{
*trustLevel = FullTrust;
return S_OK;
}

STDMETHOD(QueryInterface)(
/* [in] */ REFIID riid,
/* [iid_is][out] */ void ** ppvObject)
{
if (riid == __uuidof(IWeakReferenceSource))
{
return _weakReference->QueryInterface(riid, ppvObject);
}
return _outerUnknown->QueryInterface(riid, ppvObject);
}
STDMETHOD_(ULONG, AddRef)(void)
{
return _outerUnknown->AddRef();
}
STDMETHOD_(ULONG, Release)(void)
{
return _outerUnknown->Release();
}
};
}
extern "C" DLL_EXPORT WeakReferencableObject* STDMETHODCALLTYPE CreateWeakReferencableObject()
{
return new WeakReferencableObject();
}

extern "C" DLL_EXPORT AggregatedWeakReferenceSource* STDMETHODCALLTYPE CreateAggregatedWeakReferenceObject(IUnknown* pOuter)
{
return new AggregatedWeakReferenceSource(pOuter);
}
113 changes: 81 additions & 32 deletions src/tests/Interop/COM/ComWrappers/WeakReference/WeakReferenceTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ static class WeakReferenceNative
{
[DllImport(nameof(WeakReferenceNative))]
public static extern IntPtr CreateWeakReferencableObject();

[DllImport(nameof(WeakReferenceNative))]
public static extern IntPtr CreateAggregatedWeakReferenceObject(IntPtr outer);
}

public struct VtblPtr
Expand All @@ -28,71 +31,96 @@ public enum WrapperRegistration
Marshalling,
}

public class WeakReferenceableWrapper
public unsafe class WeakReferenceableWrapper
{
private struct Vtbl
{
public IntPtr QueryInterface;
public _AddRef AddRef;
public _Release Release;
public delegate*<IntPtr, Guid*, IntPtr*, int> QueryInterface;
public delegate*<IntPtr, int> AddRef;
public delegate*<IntPtr, int> Release;
}

private delegate int _AddRef(IntPtr This);
private delegate int _Release(IntPtr This);

private readonly IntPtr instance;
private readonly Vtbl vtable;
private readonly bool releaseInFinalizer;

public WrapperRegistration Registration { get; }

public WeakReferenceableWrapper(IntPtr instance, WrapperRegistration reg)
public WeakReferenceableWrapper(IntPtr instance, WrapperRegistration reg, bool releaseInFinalizer = true)
{
var inst = Marshal.PtrToStructure<VtblPtr>(instance);
this.vtable = Marshal.PtrToStructure<Vtbl>(inst.Vtbl);
this.instance = instance;
this.releaseInFinalizer = releaseInFinalizer;
Registration = reg;
}

public int QueryInterface(Guid iid, out IntPtr ptr)
{
fixed(IntPtr* ppv = &ptr)
{
return this.vtable.QueryInterface(this.instance, &iid, ppv);
}
}

~WeakReferenceableWrapper()
{
if (this.instance != IntPtr.Zero)
if (this.instance != IntPtr.Zero && this.releaseInFinalizer)
{
this.vtable.Release(this.instance);
}
}
}

class Program
class DerivedObject : ICustomQueryInterface
{
class TestComWrappers : ComWrappers
private WeakReferenceableWrapper inner;
public DerivedObject(TestComWrappers comWrappersInstance)
{
public WrapperRegistration Registration { get; }
IntPtr innerInstance = WeakReferenceNative.CreateAggregatedWeakReferenceObject(
comWrappersInstance.GetOrCreateComInterfaceForObject(this, CreateComInterfaceFlags.None));
inner = new WeakReferenceableWrapper(innerInstance, comWrappersInstance.Registration, releaseInFinalizer: false);
comWrappersInstance.GetOrRegisterObjectForComInstance(innerInstance, CreateObjectFlags.Aggregation, this);
}

public TestComWrappers(WrapperRegistration reg = WrapperRegistration.Local)
{
Registration = reg;
}
public CustomQueryInterfaceResult GetInterface(ref Guid iid, out IntPtr ppv)
{
return inner.QueryInterface(iid, out ppv) == 0 ? CustomQueryInterfaceResult.Handled : CustomQueryInterfaceResult.Failed;
}
}

protected unsafe override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count)
{
count = 0;
return null;
}
class TestComWrappers : ComWrappers
{
public WrapperRegistration Registration { get; }

protected override object CreateObject(IntPtr externalComObject, CreateObjectFlags flag)
{
Marshal.AddRef(externalComObject);
return new WeakReferenceableWrapper(externalComObject, Registration);
}
public TestComWrappers(WrapperRegistration reg = WrapperRegistration.Local)
{
Registration = reg;
}

protected override void ReleaseObjects(IEnumerable objects)
{
}
protected unsafe override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count)
{
count = 0;
return null;
}

protected override object CreateObject(IntPtr externalComObject, CreateObjectFlags flag)
{
Marshal.AddRef(externalComObject);
return new WeakReferenceableWrapper(externalComObject, Registration);
}

public static readonly TestComWrappers TrackerSupportInstance = new TestComWrappers(WrapperRegistration.TrackerSupport);
public static readonly TestComWrappers MarshallingInstance = new TestComWrappers(WrapperRegistration.Marshalling);
protected override void ReleaseObjects(IEnumerable objects)
{
}

public static readonly TestComWrappers TrackerSupportInstance = new TestComWrappers(WrapperRegistration.TrackerSupport);
public static readonly TestComWrappers MarshallingInstance = new TestComWrappers(WrapperRegistration.Marshalling);
}

class Program
{

private static void ValidateWeakReferenceState(WeakReference<WeakReferenceableWrapper> wr, bool expectedIsAlive, TestComWrappers sourceWrappers = null)
{
WeakReferenceableWrapper target;
Expand Down Expand Up @@ -135,7 +163,7 @@ private static void ValidateNativeWeakReference(TestComWrappers cw)
// a global ComWrappers instance. If the RCW was created throug a local ComWrappers instance, the weak
// reference should be dead and stay dead once the RCW is collected.
bool supportsRehydration = cw.Registration != WrapperRegistration.Local;

Console.WriteLine($" -- Validate RCW recreation");
ValidateWeakReferenceState(weakRef, expectedIsAlive: supportsRehydration, cw);

Expand Down Expand Up @@ -221,6 +249,26 @@ bool HasTarget(WeakReference wr)
Assert.IsNull(weakRef.Target);
}

static void ValidateAggregatedWeakReference()
{
Console.WriteLine("Validate weak reference with aggregation.");
var (handle, weakRef) = GetWeakReference();

GC.Collect();
GC.WaitForPendingFinalizers();

Assert.IsNull(handle.Target);
Assert.IsFalse(weakRef.TryGetTarget(out _));

static (GCHandle handle, WeakReference<DerivedObject>) GetWeakReference()
{
DerivedObject obj = new DerivedObject(TestComWrappers.TrackerSupportInstance);
// We use an explicit weak GC handle here to enable us to validate that we are using "weak" GCHandle
// semantics with the weak reference.
return (GCHandle.Alloc(obj, GCHandleType.Weak), new WeakReference<DerivedObject>(obj));
}
}

static int Main(string[] doNotUse)
{
try
Expand All @@ -235,6 +283,7 @@ static int Main(string[] doNotUse)

ComWrappers.RegisterForTrackerSupport(TestComWrappers.TrackerSupportInstance);
ValidateGlobalInstanceTrackerSupport();
ValidateAggregatedWeakReference();

ValidateLocalInstance();
}
Expand Down