diff --git a/src/System.Threading.ThreadPool/ref/System.Threading.ThreadPool.cs b/src/System.Threading.ThreadPool/ref/System.Threading.ThreadPool.cs index 7d1911d50af3..a975085ddd15 100644 --- a/src/System.Threading.ThreadPool/ref/System.Threading.ThreadPool.cs +++ b/src/System.Threading.ThreadPool/ref/System.Threading.ThreadPool.cs @@ -8,6 +8,10 @@ namespace System.Threading { + public interface IThreadPoolWorkItem + { + void Execute(); + } public sealed partial class RegisteredWaitHandle : System.MarshalByRefObject { internal RegisteredWaitHandle() { } @@ -33,6 +37,7 @@ public static partial class ThreadPool public static bool SetMinThreads(int workerThreads, int completionPortThreads) { throw null; } [System.CLSCompliantAttribute(false)] public static unsafe bool UnsafeQueueNativeOverlapped(System.Threading.NativeOverlapped* overlapped) { throw null; } + public static bool UnsafeQueueUserWorkItem(System.Threading.IThreadPoolWorkItem callBack, bool preferLocal) { throw null; } public static bool UnsafeQueueUserWorkItem(System.Threading.WaitCallback callBack, object state) { throw null; } public static System.Threading.RegisteredWaitHandle UnsafeRegisterWaitForSingleObject(System.Threading.WaitHandle waitObject, System.Threading.WaitOrTimerCallback callBack, object state, int millisecondsTimeOutInterval, bool executeOnlyOnce) { throw null; } public static System.Threading.RegisteredWaitHandle UnsafeRegisterWaitForSingleObject(System.Threading.WaitHandle waitObject, System.Threading.WaitOrTimerCallback callBack, object state, long millisecondsTimeOutInterval, bool executeOnlyOnce) { throw null; } diff --git a/src/System.Threading.ThreadPool/tests/ThreadPoolTests.netcoreapp.cs b/src/System.Threading.ThreadPool/tests/ThreadPoolTests.netcoreapp.cs index 4aac66daed38..da1f9b0a3287 100644 --- a/src/System.Threading.ThreadPool/tests/ThreadPoolTests.netcoreapp.cs +++ b/src/System.Threading.ThreadPool/tests/ThreadPoolTests.netcoreapp.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Linq; using System.Threading.Tasks; using Xunit; @@ -72,5 +73,83 @@ public async Task QueueUserWorkItem_PreferLocal_ExecutionContextFlowed(bool pref asyncLocal.Value = 0; Assert.Equal(42, await tcs.Task); } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public void UnsafeQueueUserWorkItem_IThreadPoolWorkItem_Invalid_Throws(bool preferLocal) + { + AssertExtensions.Throws("callBack", () => ThreadPool.UnsafeQueueUserWorkItem(null, preferLocal)); + AssertExtensions.Throws("callBack", () => ThreadPool.UnsafeQueueUserWorkItem(new InvalidWorkItemAndTask(() => { }), preferLocal)); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task UnsafeQueueUserWorkItem_IThreadPoolWorkItem_ManyIndividualItems_AllInvoked(bool preferLocal) + { + TaskCompletionSource[] tasks = Enumerable.Range(0, 100).Select(_ => new TaskCompletionSource()).ToArray(); + for (int i = 0; i < tasks.Length; i++) + { + int localI = i; + ThreadPool.UnsafeQueueUserWorkItem(new SimpleWorkItem(() => + { + tasks[localI].TrySetResult(true); + }), preferLocal); + } + await Task.WhenAll(tasks.Select(t => t.Task)); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task UnsafeQueueUserWorkItem_IThreadPoolWorkItem_SameObjectReused_AllInvoked(bool preferLocal) + { + const int Iters = 100; + int remaining = Iters; + var tcs = new TaskCompletionSource(); + var workItem = new SimpleWorkItem(() => + { + if (Interlocked.Decrement(ref remaining) == 0) + { + tcs.TrySetResult(true); + } + }); + for (int i = 0; i < Iters; i++) + { + ThreadPool.UnsafeQueueUserWorkItem(workItem, preferLocal); + } + await tcs.Task; + Assert.Equal(0, remaining); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task UnsafeQueueUserWorkItem_IThreadPoolWorkItem_ExecutionContextNotFlowed(bool preferLocal) + { + var al = new AsyncLocal { Value = 42 }; + var tcs = new TaskCompletionSource(); + ThreadPool.UnsafeQueueUserWorkItem(new SimpleWorkItem(() => + { + Assert.Equal(0, al.Value); + tcs.TrySetResult(true); + }), preferLocal); + await tcs.Task; + Assert.Equal(42, al.Value); + } + + private sealed class SimpleWorkItem : IThreadPoolWorkItem + { + private readonly Action _action; + public SimpleWorkItem(Action action) => _action = action; + public void Execute() => _action(); + } + + private sealed class InvalidWorkItemAndTask : Task, IThreadPoolWorkItem + { + public InvalidWorkItemAndTask(Action action) : base(action) { } + public void Execute() { } + } } }