Skip to content
Prev Previous commit
Next Next commit
Respond to feedback and add more tests.
This better integrates with the RemoteExecutor component as well,
by hooking up the service process and fetching its handle.

This gives us the correct logging and exitcode handling from
RemoteExecutor.
  • Loading branch information
ericstj committed Mar 30, 2023
commit 65bf8154183f2e5c88e2d8b75db52dd58a9db80b
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,6 @@ internal struct SERVICE_STATUS_PROCESS
[return: MarshalAs(UnmanagedType.Bool)]
private static unsafe partial bool QueryServiceStatusEx(SafeServiceHandle serviceHandle, int InfoLevel, SERVICE_STATUS_PROCESS* pStatus, int cbBufSize, out int pcbBytesNeeded);

internal static unsafe bool QueryServiceStatusEx(SafeServiceHandle serviceHandle, SERVICE_STATUS_PROCESS* pStatus) => QueryServiceStatusEx(serviceHandle, SC_STATUS_PROCESS_INFO, pStatus, sizeof(SERVICE_STATUS_PROCESS), out int unused);
internal static unsafe bool QueryServiceStatusEx(SafeServiceHandle serviceHandle, SERVICE_STATUS_PROCESS* pStatus) => QueryServiceStatusEx(serviceHandle, SC_STATUS_PROCESS_INFO, pStatus, sizeof(SERVICE_STATUS_PROCESS), out _);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ namespace Microsoft.Extensions.Hosting.WindowsServices
public class WindowsServiceLifetime : ServiceBase, IHostLifetime
{
private readonly TaskCompletionSource<object?> _delayStart = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
private readonly TaskCompletionSource<object?> _serviceStopped = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
private readonly TaskCompletionSource<object?> _serviceDispatcherStopped = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
private readonly ManualResetEventSlim _delayStop = new ManualResetEventSlim();
private readonly HostOptions _hostOptions;
private bool _isStopped;
private bool _serviceStopped;

/// <summary>
/// Initializes a new <see cref="WindowsServiceLifetime"/> instance.
Expand Down Expand Up @@ -89,23 +89,24 @@ private void Run()
{
Run(this); // This blocks until the service is stopped.
_delayStart.TrySetException(new InvalidOperationException("Stopped without starting"));
_serviceDispatcherStopped.TrySetResult(null);
}
catch (Exception ex)
{
_delayStart.TrySetException(ex);
_serviceDispatcherStopped.TrySetException(ex);
}
_serviceStopped.TrySetResult(null);
}

public Task StopAsync(CancellationToken cancellationToken)
public async Task StopAsync(CancellationToken cancellationToken)
{
if (!_isStopped)
if (!_serviceStopped)
{
Task.Run(Stop, CancellationToken.None);
await Task.Run(Stop, CancellationToken.None).ConfigureAwait(false);
}

// When the underlying service is stopped this will cause the ServiceBase.Run method to complete and return, which completes _serviceStopped.
return _serviceStopped.Task;
await _serviceDispatcherStopped.Task.ConfigureAwait(false);
}

// Called by base.Run when the service is ready to start.
Expand All @@ -122,7 +123,7 @@ protected override void OnStart(string[] args)
/// <remarks>This might be called multiple times by service Stop, ApplicationStopping, and StopAsync. That's okay because StopApplication uses a CancellationTokenSource and prevents any recursion.</remarks>
protected override void OnStop()
{
_isStopped = true;
_serviceStopped = true;
ApplicationLifetime.StopApplication();
// Wait for the host to shutdown before marking service as stopped.
_delayStop.Wait(_hostOptions.ShutdownTimeout);
Expand All @@ -134,7 +135,7 @@ protected override void OnStop()
/// </summary>
protected override void OnShutdown()
{
_isStopped = true;
_serviceStopped = true;
ApplicationLifetime.StopApplication();
// Wait for the host to shutdown before marking service as stopped.
_delayStop.Wait(_hostOptions.ShutdownTimeout);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public void DefaultsToOffOutsideOfService()
[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPrivilegedProcess))]
public void CanCreateService()
{
using var serviceTester = WindowsServiceTester.Create(nameof(CanCreateService), () =>
using var serviceTester = WindowsServiceTester.Create(() =>
{
using IHost host = new HostBuilder()
.UseWindowsService()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting.Internal;
using Microsoft.Extensions.Hosting.WindowsServices;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using Xunit;

Expand All @@ -18,30 +20,30 @@ namespace Microsoft.Extensions.Hosting
public class WindowsServiceLifetimeTests
{
[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPrivilegedProcess))]
public void ServiceSequenceIsCorrect()
public void ServiceStops()
{
using var serviceTester = WindowsServiceTester.Create(nameof(ServiceSequenceIsCorrect), () =>
using var serviceTester = WindowsServiceTester.Create(() =>
{
SimpleServiceLogger.InitializeForTestCase(nameof(ServiceSequenceIsCorrect));
using IHost host = new HostBuilder()
.ConfigureServices(services =>
{
services.AddHostedService<SimpleBackgroundService>();
services.AddSingleton<IHostLifetime, SimpleWindowsServiceLifetime>();
})
.Build();
var applicationLifetime = new ApplicationLifetime(NullLogger<ApplicationLifetime>.Instance);
using var lifetime = new WindowsServiceLifetime(
new HostingEnvironment(),
applicationLifetime,
NullLoggerFactory.Instance,
new OptionsWrapper<HostOptions>(new HostOptions()));

var applicationLifetime = host.Services.GetRequiredService<IHostApplicationLifetime>();
applicationLifetime.ApplicationStarted.Register(() => SimpleServiceLogger.Log($"lifetime started"));
applicationLifetime.ApplicationStopping.Register(() => SimpleServiceLogger.Log($"lifetime stopping"));
applicationLifetime.ApplicationStopped.Register(() => SimpleServiceLogger.Log($"lifetime stopped"));
lifetime.WaitForStartAsync(CancellationToken.None).GetAwaiter().GetResult();

// would normally occur here, but WindowsServiceLifetime does not depend on it.
// applicationLifetime.NotifyStarted();

// will be signaled by WindowsServiceLifetime when SCM stops the service.
applicationLifetime.ApplicationStopping.WaitHandle.WaitOne();

SimpleServiceLogger.Log("host.Run()");
host.Run();
SimpleServiceLogger.Log("host.Run() complete");
});
// required by WindowsServiceLifetime to identify that app has stopped.
applicationLifetime.NotifyStopped();

SimpleServiceLogger.DeleteLog(nameof(ServiceSequenceIsCorrect));
lifetime.StopAsync(CancellationToken.None).GetAwaiter().GetResult();
});

serviceTester.Start();
serviceTester.WaitForStatus(ServiceControllerStatus.Running);
Expand All @@ -56,8 +58,156 @@ public void ServiceSequenceIsCorrect()

var status = serviceTester.QueryServiceStatus();
Assert.Equal(0, status.win32ExitCode);
}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPrivilegedProcess))]
[SkipOnTargetFramework(TargetFrameworkMonikers.NetFramework, ".NET Framework is missing the fix from https://github.com/dotnet/corefx/commit/3e68d791066ad0fdc6e0b81828afbd9df00dd7f8")]
public void ExceptionOnStartIsPropagated()
{
using var serviceTester = WindowsServiceTester.Create(() =>
{
using (var lifetime = ThrowingWindowsServiceLifetime.Create())
{
lifetime.ThrowOnStart = new Exception("Should be thrown");
Assert.Equal(lifetime.ThrowOnStart,
Assert.Throws<Exception>( () =>
lifetime.WaitForStartAsync(CancellationToken.None).GetAwaiter().GetResult() ));
}
});

serviceTester.Start();

serviceTester.WaitForStatus(ServiceControllerStatus.Stopped);
var status = serviceTester.QueryServiceStatus();
Assert.Equal(1064, status.win32ExitCode);
}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPrivilegedProcess))]
public void ExceptionOnStopIsPropagated()
{
using var serviceTester = WindowsServiceTester.Create(() =>
{
using (var lifetime = ThrowingWindowsServiceLifetime.Create())
{
lifetime.WaitForStartAsync(CancellationToken.None).GetAwaiter().GetResult();

lifetime.ThrowOnStop = new Exception("Should be thrown");
lifetime.ApplicationLifetime.NotifyStopped();
Assert.Equal(lifetime.ThrowOnStop,
Assert.Throws<Exception>( () =>
lifetime.StopAsync(CancellationToken.None).GetAwaiter().GetResult() ));
}
});

serviceTester.Start();

// service will proceed to stopped without any error
serviceTester.WaitForStatus(ServiceControllerStatus.Stopped);
var status = serviceTester.QueryServiceStatus();
Assert.Equal(1067, status.win32ExitCode);
}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPrivilegedProcess))]
public void ServiceCanStopItself()
{
using (var serviceTester = WindowsServiceTester.Create(() =>
{
FileLogger.InitializeForTestCase(nameof(ServiceCanStopItself));
using IHost host = new HostBuilder()
.ConfigureServices(services =>
{
services.AddHostedService<LoggingBackgroundService>();
services.AddSingleton<IHostLifetime, LoggingWindowsServiceLifetime>();
})
.Build();

var applicationLifetime = host.Services.GetRequiredService<IHostApplicationLifetime>();
applicationLifetime.ApplicationStarted.Register(() => FileLogger.Log($"lifetime started"));
applicationLifetime.ApplicationStopping.Register(() => FileLogger.Log($"lifetime stopping"));
applicationLifetime.ApplicationStopped.Register(() => FileLogger.Log($"lifetime stopped"));

FileLogger.Log("host.Start()");
host.Start();

var logText = SimpleServiceLogger.ReadLog(nameof(ServiceSequenceIsCorrect));
FileLogger.Log("host.Stop()");
host.StopAsync().GetAwaiter().GetResult();
FileLogger.Log("host.Stop() complete");
}))
{
FileLogger.DeleteLog(nameof(ServiceCanStopItself));

// service should start cleanly
serviceTester.Start();

// service will proceed to stopped without any error
serviceTester.WaitForStatus(ServiceControllerStatus.Stopped);

var status = serviceTester.QueryServiceStatus();
Assert.Equal(0, status.win32ExitCode);

}

var logText = FileLogger.ReadLog(nameof(ServiceCanStopItself));
Assert.Equal("""
host.Start()
WindowsServiceLifetime.OnStart
BackgroundService.StartAsync
lifetime started
host.Stop()
lifetime stopping
BackgroundService.StopAsync
lifetime stopped
WindowsServiceLifetime.OnStop
host.Stop() complete

""", logText);
}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPrivilegedProcess))]
public void ServiceSequenceIsCorrect()
{
using (var serviceTester = WindowsServiceTester.Create(() =>
{
FileLogger.InitializeForTestCase(nameof(ServiceSequenceIsCorrect));
using IHost host = new HostBuilder()
.ConfigureServices(services =>
{
services.AddHostedService<LoggingBackgroundService>();
services.AddSingleton<IHostLifetime, LoggingWindowsServiceLifetime>();
})
.Build();

var applicationLifetime = host.Services.GetRequiredService<IHostApplicationLifetime>();
applicationLifetime.ApplicationStarted.Register(() => FileLogger.Log($"lifetime started"));
applicationLifetime.ApplicationStopping.Register(() => FileLogger.Log($"lifetime stopping"));
applicationLifetime.ApplicationStopped.Register(() => FileLogger.Log($"lifetime stopped"));

FileLogger.Log("host.Run()");
host.Run();
FileLogger.Log("host.Run() complete");
}))
{

FileLogger.DeleteLog(nameof(ServiceSequenceIsCorrect));

serviceTester.Start();
serviceTester.WaitForStatus(ServiceControllerStatus.Running);

var statusEx = serviceTester.QueryServiceStatusEx();
var serviceProcess = Process.GetProcessById(statusEx.dwProcessId);

// Give a chance for all asynchronous "started" events to be raised, these happen after the service status changes to started
Thread.Sleep(1000);

serviceTester.Stop();
serviceTester.WaitForStatus(ServiceControllerStatus.Stopped);

var status = serviceTester.QueryServiceStatus();
Assert.Equal(0, status.win32ExitCode);

}

var logText = FileLogger.ReadLog(nameof(ServiceSequenceIsCorrect));
Assert.Equal("""
host.Run()
WindowsServiceLifetime.OnStart
Expand All @@ -73,35 +223,77 @@ lifetime stopped

}

public class SimpleWindowsServiceLifetime : WindowsServiceLifetime
public class LoggingWindowsServiceLifetime : WindowsServiceLifetime
{
public SimpleWindowsServiceLifetime(IHostEnvironment environment, IHostApplicationLifetime applicationLifetime, ILoggerFactory loggerFactory, IOptions<HostOptions> optionsAccessor) :
public LoggingWindowsServiceLifetime(IHostEnvironment environment, IHostApplicationLifetime applicationLifetime, ILoggerFactory loggerFactory, IOptions<HostOptions> optionsAccessor) :
base(environment, applicationLifetime, loggerFactory, optionsAccessor)
{ }

protected override void OnStart(string[] args)
{
SimpleServiceLogger.Log("WindowsServiceLifetime.OnStart");
FileLogger.Log("WindowsServiceLifetime.OnStart");
base.OnStart(args);
}

protected override void OnStop()
{
SimpleServiceLogger.Log("WindowsServiceLifetime.OnStop");
FileLogger.Log("WindowsServiceLifetime.OnStop");
base.OnStop();
}
}

public class ThrowingWindowsServiceLifetime : WindowsServiceLifetime
{
public static ThrowingWindowsServiceLifetime Create(Exception throwOnStart = null, Exception throwOnStop = null) =>
new ThrowingWindowsServiceLifetime(
new HostingEnvironment(),
new ApplicationLifetime(NullLogger<ApplicationLifetime>.Instance),
NullLoggerFactory.Instance,
new OptionsWrapper<HostOptions>(new HostOptions()))
{
ThrowOnStart = throwOnStart,
ThrowOnStop = throwOnStop
};

public ThrowingWindowsServiceLifetime(IHostEnvironment environment, ApplicationLifetime applicationLifetime, ILoggerFactory loggerFactory, IOptions<HostOptions> optionsAccessor) :
base(environment, applicationLifetime, loggerFactory, optionsAccessor)
{
ApplicationLifetime = applicationLifetime;
}

public ApplicationLifetime ApplicationLifetime { get; }

public Exception ThrowOnStart { get; set; }
protected override void OnStart(string[] args)
{
if (ThrowOnStart != null)
{
throw ThrowOnStart;
}
base.OnStart(args);
}

public Exception ThrowOnStop { get; set; }
protected override void OnStop()
{
if (ThrowOnStop != null)
{
throw ThrowOnStop;
}
base.OnStop();
}
}

public class SimpleBackgroundService : BackgroundService
public class LoggingBackgroundService : BackgroundService
{
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
protected override async Task ExecuteAsync(CancellationToken stoppingToken) => SimpleServiceLogger.Log("BackgroundService.ExecuteAsync");
public override async Task StartAsync(CancellationToken stoppingToken) => SimpleServiceLogger.Log("BackgroundService.StartAsync");
public override async Task StopAsync(CancellationToken stoppingToken) => SimpleServiceLogger.Log("BackgroundService.StopAsync");
protected override async Task ExecuteAsync(CancellationToken stoppingToken) => FileLogger.Log("BackgroundService.ExecuteAsync");
public override async Task StartAsync(CancellationToken stoppingToken) => FileLogger.Log("BackgroundService.StartAsync");
public override async Task StopAsync(CancellationToken stoppingToken) => FileLogger.Log("BackgroundService.StopAsync");
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
}

static class SimpleServiceLogger
static class FileLogger
{
static string _fileName;

Expand Down
Loading