Skip to content

Commit 16f8dac

Browse files
committed
Apply suggestions from code review
1 parent b0346ab commit 16f8dac

File tree

1 file changed

+40
-47
lines changed

1 file changed

+40
-47
lines changed

src/WinRT.Runtime/Interop/IContextCallback.cs

Lines changed: 40 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System;
55
using System.Runtime.CompilerServices;
66
using System.Runtime.InteropServices;
7+
using System.Threading;
78
using WinRT;
89
using WinRT.Interop;
910

@@ -19,36 +20,14 @@ internal struct ComCallData
1920
}
2021

2122
#if NET && CsWinRT_LANG_11_FEATURES
22-
internal sealed unsafe class CallbackData
23+
internal unsafe struct CallbackData
2324
{
2425
[ThreadStatic]
25-
private static CallbackData TlsInstance;
26+
public static object PerThreadObject;
2627

2728
public delegate*<object, void> Callback;
28-
public object State;
29-
public GCHandle Handle;
30-
31-
private CallbackData()
32-
{
33-
// Create a handle to access the object from a native callback invoked on another thread.
34-
// The handle is weak to ensure that the object does not leak (or it would keep itself
35-
// alive). The target is guaranteed to be alive because callers will use 'GC.KeepAlive'.
36-
Handle = GCHandle.Alloc(this, GCHandleType.Weak);
37-
}
38-
39-
~CallbackData()
40-
{
41-
Handle.Free();
42-
}
43-
44-
[MethodImpl(MethodImplOptions.AggressiveInlining)]
45-
public static CallbackData GetOrCreate()
46-
{
47-
return TlsInstance ??= new CallbackData();
48-
}
29+
public object* StatePtr;
4930
}
50-
51-
5231
#endif
5332

5433
#if NET && CsWinRT_LANG_11_FEATURES
@@ -61,22 +40,19 @@ internal unsafe struct IContextCallbackVftbl
6140

6241
public static void ContextCallback(IntPtr contextCallbackPtr, delegate*<object, void> callback, delegate*<object, void> onFailCallback, object state)
6342
{
64-
ComCallData comCallData;
65-
comCallData.dwDispid = 0;
66-
comCallData.dwReserved = 0;
67-
68-
CallbackData callbackData = CallbackData.GetOrCreate();
69-
70-
comCallData.pUserDefined = GCHandle.ToIntPtr(callbackData.Handle);
71-
43+
// Native method that invokes the callback on the target context. The state object
44+
// is guaranteed to be pinned, so we can access it from a pointer. Note that the
45+
// object will be stored in a static field, and it will not be on the stack of the
46+
// original thread, so it's safe with respect to cross-thread access of managed objects.
47+
// See: https://github.com/dotnet/runtime/blob/main/docs/design/specs/Memory-model.md#cross-thread-access-to-local-variables.
7248
[UnmanagedCallersOnly]
7349
static int InvokeCallback(ComCallData* comCallData)
7450
{
7551
try
7652
{
77-
CallbackData callbackData = Unsafe.As<CallbackData>(GCHandle.FromIntPtr(comCallData->pUserDefined).Target);
53+
CallbackData* callbackData = (CallbackData*)comCallData->pUserDefined;
7854

79-
callbackData.Callback(callbackData.State);
55+
callbackData->Callback(*callbackData->StatePtr);
8056

8157
return 0; // S_OK
8258
}
@@ -86,20 +62,37 @@ static int InvokeCallback(ComCallData* comCallData)
8662
}
8763
}
8864

89-
Guid iid = IID.IID_ICallbackWithNoReentrancyToApplicationSTA;
65+
ComCallData comCallData;
66+
comCallData.dwDispid = 0;
67+
comCallData.dwReserved = 0;
68+
69+
CallbackData.PerThreadObject = state;
70+
71+
int hresult;
9072

91-
int hresult = (*(IContextCallbackVftbl**)contextCallbackPtr)->ContextCallback_4(
92-
contextCallbackPtr,
93-
(IntPtr)(delegate* unmanaged<ComCallData*, int>)&InvokeCallback,
94-
&comCallData,
95-
&iid,
96-
/* iMethod */ 5,
97-
IntPtr.Zero);
73+
fixed (object* statePtr = &CallbackData.PerThreadObject)
74+
{
75+
CallbackData callbackData;
76+
callbackData.Callback = callback;
77+
callbackData.StatePtr = statePtr;
78+
79+
Guid iid = IID.IID_ICallbackWithNoReentrancyToApplicationSTA;
80+
81+
// Add a memory barrier to be extra safe that the target thread will be able to see
82+
// the write we just did on 'PerThreadObject' with the state to pass to the callback.
83+
Thread.MemoryBarrier();
84+
85+
hresult = (*(IContextCallbackVftbl**)contextCallbackPtr)->ContextCallback_4(
86+
contextCallbackPtr,
87+
(IntPtr)(delegate* unmanaged<ComCallData*, int>)&InvokeCallback,
88+
&comCallData,
89+
&iid,
90+
/* iMethod */ 5,
91+
IntPtr.Zero);
92+
}
9893

99-
// This call is critical to ensure that the callback data is kept alive until we get here.
100-
// This prevents its finalizer to run (that finalizer would free the GC handle used in the
101-
// native callback to get back the target callback data that contains the dispatch parameters).
102-
GC.KeepAlive(callbackData);
94+
// Reset the static field to avoid keeping the state alive for longer
95+
Volatile.Write(ref CallbackData.PerThreadObject, null);
10396

10497
if (hresult < 0)
10598
{

0 commit comments

Comments
 (0)