diff --git a/eng/Version.Details.xml b/eng/Version.Details.xml index 28d522cbfe01..cbb97411e108 100644 --- a/eng/Version.Details.xml +++ b/eng/Version.Details.xml @@ -9,286 +9,286 @@ --> - + https://github.com/dotnet/efcore - 0faf528551a21deeeb761fe80fec125661d2019b + 382a6cb8f738440bbb7544e69e2ecb0fce695bd5 - + https://github.com/dotnet/efcore - 0faf528551a21deeeb761fe80fec125661d2019b + 382a6cb8f738440bbb7544e69e2ecb0fce695bd5 - + https://github.com/dotnet/efcore - 0faf528551a21deeeb761fe80fec125661d2019b + 382a6cb8f738440bbb7544e69e2ecb0fce695bd5 - + https://github.com/dotnet/efcore - 0faf528551a21deeeb761fe80fec125661d2019b + 382a6cb8f738440bbb7544e69e2ecb0fce695bd5 - + https://github.com/dotnet/efcore - 0faf528551a21deeeb761fe80fec125661d2019b + 382a6cb8f738440bbb7544e69e2ecb0fce695bd5 - + https://github.com/dotnet/efcore - 0faf528551a21deeeb761fe80fec125661d2019b + 382a6cb8f738440bbb7544e69e2ecb0fce695bd5 - + https://github.com/dotnet/efcore - 0faf528551a21deeeb761fe80fec125661d2019b + 382a6cb8f738440bbb7544e69e2ecb0fce695bd5 - + https://github.com/dotnet/efcore - 0faf528551a21deeeb761fe80fec125661d2019b + 382a6cb8f738440bbb7544e69e2ecb0fce695bd5 - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed https://github.com/dotnet/source-build-externals e53b62ccc6a887987efdb820334594b674f6071d - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed https://github.com/dotnet/xdt @@ -298,9 +298,9 @@ - + https://github.com/dotnet/runtime - b83539cbf22d4197f4b0004101b42abcad56316c + 8bb1983a4a231bdbb0947c65643d7edc48c8e1ed https://github.com/dotnet/arcade diff --git a/eng/Versions.props b/eng/Versions.props index e79c79a32854..e44cd90679a7 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -63,77 +63,77 @@ --> - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 - 7.0.0-rc.2.22458.4 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 + 7.0.0-rc.2.22459.6 - 7.0.0-rc.2.22458.4 + 7.0.0-rc.2.22459.6 - 7.0.0-rc.2.22458.14 - 7.0.0-rc.2.22458.14 - 7.0.0-rc.2.22458.14 - 7.0.0-rc.2.22458.14 - 7.0.0-rc.2.22458.14 - 7.0.0-rc.2.22458.14 - 7.0.0-rc.2.22458.14 - 7.0.0-rc.2.22458.14 + 7.0.0-rc.2.22459.14 + 7.0.0-rc.2.22459.14 + 7.0.0-rc.2.22459.14 + 7.0.0-rc.2.22459.14 + 7.0.0-rc.2.22459.14 + 7.0.0-rc.2.22459.14 + 7.0.0-rc.2.22459.14 + 7.0.0-rc.2.22459.14 7.0.0-beta.22423.1 7.0.0-beta.22423.1 diff --git a/src/Components/Web/src/Routing/NavigationLock.cs b/src/Components/Web/src/Routing/NavigationLock.cs index 2dd9aac504f6..e8b5566531de 100644 --- a/src/Components/Web/src/Routing/NavigationLock.cs +++ b/src/Components/Web/src/Routing/NavigationLock.cs @@ -9,13 +9,16 @@ namespace Microsoft.AspNetCore.Components.Routing; /// /// A component that can be used to intercept navigation events. /// -public sealed class NavigationLock : IComponent, IAsyncDisposable +public sealed class NavigationLock : IComponent, IHandleAfterRender, IAsyncDisposable { private readonly string _id = Guid.NewGuid().ToString("D", CultureInfo.InvariantCulture); + private RenderHandle _renderHandle; private IDisposable? _locationChangingRegistration; + private bool _hasLocationChangingHandler; + private bool _confirmExternalNavigation; - private bool HasOnBeforeInternalNavigationCallback => OnBeforeInternalNavigation.HasDelegate; + private bool HasLocationChangingHandler => OnBeforeInternalNavigation.HasDelegate; [Inject] private IJSRuntime JSRuntime { get; set; } = default!; @@ -38,28 +41,51 @@ public sealed class NavigationLock : IComponent, IAsyncDisposable void IComponent.Attach(RenderHandle renderHandle) { + _renderHandle = renderHandle; } - async Task IComponent.SetParametersAsync(ParameterView parameters) + Task IComponent.SetParametersAsync(ParameterView parameters) { - var lastHasOnBeforeInternalNavigationCallback = HasOnBeforeInternalNavigationCallback; - var lastConfirmExternalNavigation = ConfirmExternalNavigation; + foreach (var parameter in parameters) + { + if (parameter.Name.Equals(nameof(OnBeforeInternalNavigation), StringComparison.OrdinalIgnoreCase)) + { + OnBeforeInternalNavigation = (EventCallback)parameter.Value; + } + else if (parameter.Name.Equals(nameof(ConfirmExternalNavigation), StringComparison.OrdinalIgnoreCase)) + { + ConfirmExternalNavigation = (bool)parameter.Value; + } + else + { + throw new ArgumentException($"The component '{nameof(NavigationLock)}' does not accept a parameter with the name '{parameter.Name}'."); + } + } - parameters.SetParameterProperties(this); + if (_hasLocationChangingHandler != HasLocationChangingHandler || + _confirmExternalNavigation != ConfirmExternalNavigation) + { + _renderHandle.Render(static builder => { }); + } + + return Task.CompletedTask; + } - var hasOnBeforeInternalNavigationCallback = HasOnBeforeInternalNavigationCallback; - if (hasOnBeforeInternalNavigationCallback != lastHasOnBeforeInternalNavigationCallback) + async Task IHandleAfterRender.OnAfterRenderAsync() + { + if (_hasLocationChangingHandler != HasLocationChangingHandler) { + _hasLocationChangingHandler = HasLocationChangingHandler; _locationChangingRegistration?.Dispose(); - _locationChangingRegistration = hasOnBeforeInternalNavigationCallback + _locationChangingRegistration = _hasLocationChangingHandler ? NavigationManager.RegisterLocationChangingHandler(OnLocationChanging) : null; } - var confirmExternalNavigation = ConfirmExternalNavigation; - if (confirmExternalNavigation != lastConfirmExternalNavigation) + if (_confirmExternalNavigation != ConfirmExternalNavigation) { - if (confirmExternalNavigation) + _confirmExternalNavigation = ConfirmExternalNavigation; + if (_confirmExternalNavigation) { await JSRuntime.InvokeVoidAsync(NavigationLockInterop.EnableNavigationPrompt, _id); } @@ -70,7 +96,7 @@ async Task IComponent.SetParametersAsync(ParameterView parameters) } } - async ValueTask OnLocationChanging(LocationChangingContext context) + private async ValueTask OnLocationChanging(LocationChangingContext context) { await OnBeforeInternalNavigation.InvokeAsync(context); } @@ -78,6 +104,10 @@ async ValueTask OnLocationChanging(LocationChangingContext context) async ValueTask IAsyncDisposable.DisposeAsync() { _locationChangingRegistration?.Dispose(); - await JSRuntime.InvokeVoidAsync(NavigationLockInterop.DisableNavigationPrompt, _id); + + if (_confirmExternalNavigation) + { + await JSRuntime.InvokeVoidAsync(NavigationLockInterop.DisableNavigationPrompt, _id); + } } } diff --git a/src/Components/WebAssembly/WebAssembly.Authentication/src/RemoteAuthenticatorViewCore.cs b/src/Components/WebAssembly/WebAssembly.Authentication/src/RemoteAuthenticatorViewCore.cs index 95c178a45698..1f2d5788e8df 100644 --- a/src/Components/WebAssembly/WebAssembly.Authentication/src/RemoteAuthenticatorViewCore.cs +++ b/src/Components/WebAssembly/WebAssembly.Authentication/src/RemoteAuthenticatorViewCore.cs @@ -17,6 +17,7 @@ public partial class RemoteAuthenticatorViewCore<[DynamicallyAccessedMembers(Jso { private RemoteAuthenticationApplicationPathsOptions _applicationPaths; private string _action; + private string _lastHandledAction; private InteractiveRequestOptions _cachedRequest; private static readonly NavigationOptions AuthenticationNavigationOptions = @@ -152,6 +153,13 @@ protected override void BuildRenderTree(RenderTreeBuilder builder) /// protected override async Task OnParametersSetAsync() { + if (_lastHandledAction == Action) + { + // Avoid processing the same action more than once. + return; + } + + _lastHandledAction = Action; Log.ProcessingAuthenticatorAction(Logger, Action); switch (Action) { diff --git a/src/Components/WebAssembly/WebAssembly.Authentication/test/RemoteAuthenticatorCoreTests.cs b/src/Components/WebAssembly/WebAssembly.Authentication/test/RemoteAuthenticatorCoreTests.cs index e6a7e7a2968d..68ce23ac5ffe 100644 --- a/src/Components/WebAssembly/WebAssembly.Authentication/test/RemoteAuthenticatorCoreTests.cs +++ b/src/Components/WebAssembly/WebAssembly.Authentication/test/RemoteAuthenticatorCoreTests.cs @@ -236,6 +236,55 @@ public async Task AuthenticationManager_LoginCallback_NavigatesToLoginFailureOnE ((TestNavigationManager)remoteAuthenticator.Navigation).HistoryEntryState); } + [Fact] + public async Task AuthenticationManager_Callbacks_OnlyExecuteOncePerAction() + { + // Arrange + var (remoteAuthenticator, renderer, authServiceMock) = CreateAuthenticationManager( + "https://www.example.com/base/authentication/login-callback?code=1234"); + + authServiceMock.CompleteSignInCallback = s => Task.FromResult(new RemoteAuthenticationResult() + { + Status = RemoteAuthenticationStatus.Success, + }); + + authServiceMock.CompleteSignOutCallback = s => Task.FromResult(new RemoteAuthenticationResult() + { + Status = RemoteAuthenticationStatus.Success, + }); + + var logInCallbackInvocationCount = 0; + var logOutCallbackInvocationCount = 0; + + var parameterDictionary = new Dictionary + { + [_action] = RemoteAuthenticationActions.LogInCallback, + [_onLogInSucceded] = new EventCallbackFactory().Create( + remoteAuthenticator, + (state) => logInCallbackInvocationCount++), + [_onLogOutSucceeded] = new EventCallbackFactory().Create( + remoteAuthenticator, + (state) => logOutCallbackInvocationCount++) + }; + + var initialParameters = ParameterView.FromDictionary(parameterDictionary); + + parameterDictionary[_action] = RemoteAuthenticationActions.LogOutCallback; + + var finalParameters = ParameterView.FromDictionary(parameterDictionary); + + // Act + await renderer.Dispatcher.InvokeAsync(() => remoteAuthenticator.SetParametersAsync(initialParameters)); + await renderer.Dispatcher.InvokeAsync(() => remoteAuthenticator.SetParametersAsync(initialParameters)); + + await renderer.Dispatcher.InvokeAsync(() => remoteAuthenticator.SetParametersAsync(finalParameters)); + await renderer.Dispatcher.InvokeAsync(() => remoteAuthenticator.SetParametersAsync(finalParameters)); + + // Assert + Assert.Equal(1, logInCallbackInvocationCount); + Assert.Equal(1, logOutCallbackInvocationCount); + } + [Fact] public async Task AuthenticationManager_Logout_NavigatesToReturnUrlOnSuccess() { diff --git a/src/Components/test/E2ETest/ServerExecutionTests/NavigationLockPrerenderingTest.cs b/src/Components/test/E2ETest/ServerExecutionTests/NavigationLockPrerenderingTest.cs new file mode 100644 index 000000000000..e59585d80533 --- /dev/null +++ b/src/Components/test/E2ETest/ServerExecutionTests/NavigationLockPrerenderingTest.cs @@ -0,0 +1,50 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.AspNetCore.Components.E2ETest.Infrastructure; +using Microsoft.AspNetCore.Components.E2ETest.Infrastructure.ServerFixtures; +using Microsoft.AspNetCore.E2ETesting; +using OpenQA.Selenium; +using TestServer; +using Xunit.Abstractions; + +namespace Microsoft.AspNetCore.Components.E2ETest.ServerExecutionTests; + +public class NavigationLockPrerenderingTest : ServerTestBase> +{ + public NavigationLockPrerenderingTest( + BrowserFixture browserFixture, + BasicTestAppServerSiteFixture serverFixture, + ITestOutputHelper output) + : base(browserFixture, serverFixture, output) + { + } + + [Fact] + public void NavigationIsLockedAfterPrerendering() + { + Navigate("/locked-navigation"); + + // Assert that the component rendered successfully + Browser.Equal("Prevented navigations: 0", () => Browser.FindElement(By.Id("num-prevented-navigations")).Text); + + BeginInteractivity(); + + // Assert that internal navigations are blocked + Browser.Click(By.Id("internal-navigation-link")); + Browser.Equal("Prevented navigations: 1", () => Browser.FindElement(By.Id("num-prevented-navigations")).Text); + + // Assert that external navigations are blocked + Browser.Navigate().GoToUrl("about:blank"); + Browser.SwitchTo().Alert().Dismiss(); + Browser.Equal("Prevented navigations: 1", () => Browser.FindElement(By.Id("num-prevented-navigations")).Text); + } + + private void BeginInteractivity() + { + Browser.Exists(By.Id("load-boot-script")).Click(); + + var javascript = (IJavaScriptExecutor)Browser; + Browser.True(() => (bool)javascript.ExecuteScript("return window['__aspnetcore__testing__blazor__started__'] === true;")); + } +} diff --git a/src/Components/test/testassets/BasicTestApp/RouterTest/LockNavigationOnPageLoad.razor b/src/Components/test/testassets/BasicTestApp/RouterTest/LockNavigationOnPageLoad.razor new file mode 100644 index 000000000000..14f901b38685 --- /dev/null +++ b/src/Components/test/testassets/BasicTestApp/RouterTest/LockNavigationOnPageLoad.razor @@ -0,0 +1,28 @@ +@using Microsoft.AspNetCore.Components.Routing + +@inject INavigationInterception NavigationInterception + +Internal navigation + +Prevented navigations: @_numPreventedNavigations + + + +@code { + private int _numPreventedNavigations = 0; + + protected override async Task OnAfterRenderAsync(bool firstRender) + { + if (firstRender) + { + await NavigationInterception.EnableNavigationInterceptionAsync(); + } + } + + private Task HandleBeforeInternalNavigationAsync(LocationChangingContext context) + { + _numPreventedNavigations++; + context.PreventNavigation(); + return Task.CompletedTask; + } +} diff --git a/src/Components/test/testassets/TestServer/LockedNavigationStartup.cs b/src/Components/test/testassets/TestServer/LockedNavigationStartup.cs new file mode 100644 index 000000000000..13a3c95fc053 --- /dev/null +++ b/src/Components/test/testassets/TestServer/LockedNavigationStartup.cs @@ -0,0 +1,46 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Globalization; +using Microsoft.AspNetCore.Authentication.Cookies; + +namespace TestServer; + +public class LockedNavigationStartup +{ + // This method gets called by the runtime. Use this method to add services to the container. + public void ConfigureServices(IServiceCollection services) + { + services.AddMvc(); + services.AddServerSideBlazor(); + services.AddAuthentication(CookieAuthenticationDefaults.AuthenticationScheme).AddCookie(); + } + + // This method gets called by the runtime. Use this method to configure the HTTP request pipeline. + public void Configure(IApplicationBuilder app, IWebHostEnvironment env) + { + var enUs = new CultureInfo("en-US"); + CultureInfo.DefaultThreadCurrentCulture = enUs; + CultureInfo.DefaultThreadCurrentUICulture = enUs; + + if (env.IsDevelopment()) + { + app.UseDeveloperExceptionPage(); + } + + app.Map("/locked-navigation", app => + { + app.UseStaticFiles(); + + app.UseAuthentication(); + + app.UseRouting(); + app.UseEndpoints(endpoints => + { + endpoints.MapRazorPages(); + endpoints.MapFallbackToPage("/LockedNavigationHost"); + endpoints.MapBlazorHub(); + }); + }); + } +} diff --git a/src/Components/test/testassets/TestServer/Pages/LockedNavigationHost.cshtml b/src/Components/test/testassets/TestServer/Pages/LockedNavigationHost.cshtml new file mode 100644 index 000000000000..5fe918596090 --- /dev/null +++ b/src/Components/test/testassets/TestServer/Pages/LockedNavigationHost.cshtml @@ -0,0 +1,33 @@ +@page "/locked-navigation" +@using BasicTestApp.RouterTest + + + + + Locked navigation + + + + + + @* + So that E2E tests can make assertions about both the prerendered and + interactive states, we only load the .js file when told to. + *@ +
+ + + + + + + + diff --git a/src/Components/test/testassets/TestServer/Program.cs b/src/Components/test/testassets/TestServer/Program.cs index 73b7875fbfdf..80c07b6da0a9 100644 --- a/src/Components/test/testassets/TestServer/Program.cs +++ b/src/Components/test/testassets/TestServer/Program.cs @@ -20,6 +20,7 @@ public static async Task Main(string[] args) ["CORS (WASM)"] = (BuildWebHost(CreateAdditionalArgs(args)), "/subdir"), ["Prerendering (Server-side)"] = (BuildWebHost(CreateAdditionalArgs(args)), "/prerendered"), ["Deferred component content (Server-side)"] = (BuildWebHost(CreateAdditionalArgs(args)), "/deferred-component-content"), + ["Locked navigation (Server-side)"] = (BuildWebHost(CreateAdditionalArgs(args)), "/locked-navigation"), ["Client-side with fallback"] = (BuildWebHost(CreateAdditionalArgs(args)), "/fallback"), ["Multiple components (Server-side)"] = (BuildWebHost(CreateAdditionalArgs(args)), "/multiple-components"), ["Save state"] = (BuildWebHost(CreateAdditionalArgs(args)), "/save-state"), diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index 3a27258667f7..30db3740bb81 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -645,6 +645,22 @@ private static Expression CreateArgument(ParameterInfo parameter, RequestDelegat throw new InvalidOperationException($"Encountered a parameter of type '{parameter.ParameterType}' without a name. Parameters must have a name."); } + if (parameter.ParameterType.IsByRef) + { + var attribute = "ref"; + + if (parameter.Attributes.HasFlag(ParameterAttributes.In)) + { + attribute = "in"; + } + else if (parameter.Attributes.HasFlag(ParameterAttributes.Out)) + { + attribute = "out"; + } + + throw new NotSupportedException($"The by reference parameter '{attribute} {TypeNameHelper.GetTypeDisplayName(parameter.ParameterType, fullName: false)} {parameter.Name}' is not supported."); + } + var parameterCustomAttributes = parameter.GetCustomAttributes(); if (parameterCustomAttributes.OfType().FirstOrDefault() is { } routeAttribute) diff --git a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs index ebe5a420c9a5..8960560ae5d6 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs @@ -2634,6 +2634,24 @@ public static object[][] FromServiceActions } } + [Fact] + public void BuildRequestDelegateThrowsNotSupportedExceptionForByRefParameters() + { + void OutMethod(out string foo) { foo = ""; } + void InMethod(in string foo) { } + void RefMethod(ref string foo) { } + + var outParamException = Assert.Throws(() => RequestDelegateFactory.Create(OutMethod)); + var inParamException = Assert.Throws(() => RequestDelegateFactory.Create(InMethod)); + var refParamException = Assert.Throws(() => RequestDelegateFactory.Create(RefMethod)); + + var typeName = typeof(string).MakeByRefType().Name; + + Assert.Equal($"The by reference parameter 'out {typeName} foo' is not supported.", outParamException.Message); + Assert.Equal($"The by reference parameter 'in {typeName} foo' is not supported.", inParamException.Message); + Assert.Equal($"The by reference parameter 'ref {typeName} foo' is not supported.", refParamException.Message); + } + [Theory] [MemberData(nameof(ImplicitFromServiceActions))] public async Task RequestDelegateRequiresServiceForAllImplicitFromServiceParameters(Delegate action) diff --git a/src/Middleware/OutputCaching/src/Memory/MemoryOutputCacheStore.cs b/src/Middleware/OutputCaching/src/Memory/MemoryOutputCacheStore.cs index 3f1df73c7b80..a75546b6793f 100644 --- a/src/Middleware/OutputCaching/src/Memory/MemoryOutputCacheStore.cs +++ b/src/Middleware/OutputCaching/src/Memory/MemoryOutputCacheStore.cs @@ -1,23 +1,27 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics; using Microsoft.Extensions.Caching.Memory; namespace Microsoft.AspNetCore.OutputCaching.Memory; internal sealed class MemoryOutputCacheStore : IOutputCacheStore { - private readonly IMemoryCache _cache; + private readonly MemoryCache _cache; private readonly Dictionary> _taggedEntries = new(); private readonly object _tagsLock = new(); - internal MemoryOutputCacheStore(IMemoryCache cache) + internal MemoryOutputCacheStore(MemoryCache cache) { ArgumentNullException.ThrowIfNull(cache); _cache = cache; } + // For testing + internal Dictionary> TaggedEntries => _taggedEntries; + public ValueTask EvictByTagAsync(string tag, CancellationToken cancellationToken) { ArgumentNullException.ThrowIfNull(tag); @@ -26,12 +30,29 @@ public ValueTask EvictByTagAsync(string tag, CancellationToken cancellationToken { if (_taggedEntries.TryGetValue(tag, out var keys)) { - foreach (var key in keys) + if (keys != null && keys.Count > 0) { - _cache.Remove(key); - } + // If MemoryCache changed to run eviction callbacks inline in Remove, iterating over keys could throw + // To prevent allocating a copy of the keys we check if the eviction callback ran, + // and if it did we restart the loop. - _taggedEntries.Remove(tag); + var i = keys.Count; + while (i > 0) + { + var oldCount = keys.Count; + foreach (var key in keys) + { + _cache.Remove(key); + i--; + if (oldCount != keys.Count) + { + // eviction callback ran inline, we need to restart the loop to avoid + // "collection modified while iterating" errors + break; + } + } + } + } } } @@ -62,35 +83,75 @@ public ValueTask SetAsync(string key, byte[] value, string[]? tags, TimeSpan val { foreach (var tag in tags) { + if (tag is null) + { + throw new ArgumentException(Resources.TagCannotBeNull); + } + if (!_taggedEntries.TryGetValue(tag, out var keys)) { keys = new HashSet(); _taggedEntries[tag] = keys; } + Debug.Assert(keys != null); + keys.Add(key); } - SetEntry(); + SetEntry(key, value, tags, validFor); } } else { - SetEntry(); + SetEntry(key, value, tags, validFor); } - void SetEntry() + return ValueTask.CompletedTask; + } + + void SetEntry(string key, byte[] value, string[]? tags, TimeSpan validFor) + { + Debug.Assert(key != null); + + var options = new MemoryCacheEntryOptions { - _cache.Set( - key, - value, - new MemoryCacheEntryOptions - { - AbsoluteExpirationRelativeToNow = validFor, - Size = value.Length - }); + AbsoluteExpirationRelativeToNow = validFor, + Size = value.Length + }; + + if (tags != null && tags.Length > 0) + { + // Remove cache keys from tag lists when the entry is evicted + options.RegisterPostEvictionCallback(RemoveFromTags, tags); } - return ValueTask.CompletedTask; + _cache.Set(key, value, options); + } + + void RemoveFromTags(object key, object? value, EvictionReason reason, object? state) + { + var tags = state as string[]; + + Debug.Assert(tags != null); + Debug.Assert(tags.Length > 0); + Debug.Assert(key is string); + + lock (_tagsLock) + { + foreach (var tag in tags) + { + if (_taggedEntries.TryGetValue(tag, out var tagged)) + { + tagged.Remove((string)key); + + // Remove the collection if there is no more keys in it + if (tagged.Count == 0) + { + _taggedEntries.Remove(tag); + } + } + } + } } } diff --git a/src/Middleware/OutputCaching/src/Resources.resx b/src/Middleware/OutputCaching/src/Resources.resx index 3a19868a7302..621b8c2fc42c 100644 --- a/src/Middleware/OutputCaching/src/Resources.resx +++ b/src/Middleware/OutputCaching/src/Resources.resx @@ -117,7 +117,7 @@ System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 - - The type '{0}' is not a valid output policy. + + A tag value cannot be null. \ No newline at end of file diff --git a/src/Middleware/OutputCaching/test/MemoryOutputCacheStoreTests.cs b/src/Middleware/OutputCaching/test/MemoryOutputCacheStoreTests.cs index e15f4515ca11..9441520b849f 100644 --- a/src/Middleware/OutputCaching/test/MemoryOutputCacheStoreTests.cs +++ b/src/Middleware/OutputCaching/test/MemoryOutputCacheStoreTests.cs @@ -71,7 +71,19 @@ public async Task EvictByTag_SingleTag_SingleEntry() await store.EvictByTagAsync("tag1", default); var result = await store.GetAsync(key, default); + HashSet tag1s; + + // Wait for the hashset to be removed as it happens on a separate thread + + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30)); + + while (store.TaggedEntries.TryGetValue("tag1", out tag1s) && !cts.IsCancellationRequested) + { + await Task.Yield(); + } + Assert.Null(result); + Assert.Null(tag1s); } [Fact] @@ -140,6 +152,62 @@ public async Task EvictByTag_MultipleTags_MultipleEntries() Assert.Null(result2); } + [Fact] + public async Task ExpiredEntries_AreRemovedFromTags() + { + var testClock = new TestMemoryOptionsClock { UtcNow = DateTimeOffset.UtcNow }; + var cache = new MemoryCache(new MemoryCacheOptions { SizeLimit = 1000, Clock = testClock, ExpirationScanFrequency = TimeSpan.FromMilliseconds(1) }); + var store = new MemoryOutputCacheStore(cache); + var value = "abc"u8.ToArray(); + + await store.SetAsync("a", value, new[] { "tag1" }, TimeSpan.FromMilliseconds(5), default); + await store.SetAsync("b", value, new[] { "tag1", "tag2" }, TimeSpan.FromMilliseconds(5), default); + await store.SetAsync("c", value, new[] { "tag2" }, TimeSpan.FromMilliseconds(20), default); + + testClock.Advance(TimeSpan.FromMilliseconds(10)); + + // Background expiration checks are triggered by misc cache activity. + _ = cache.Get("a"); + + var resulta = await store.GetAsync("a", default); + var resultb = await store.GetAsync("b", default); + var resultc = await store.GetAsync("c", default); + + Assert.Null(resulta); + Assert.Null(resultb); + Assert.NotNull(resultc); + + HashSet tag1s, tag2s; + + // Wait for the hashset to be removed as it happens on a separate thread + + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30)); + + while (store.TaggedEntries.TryGetValue("tag1", out tag1s) && !cts.IsCancellationRequested) + { + await Task.Yield(); + } + + while (store.TaggedEntries.TryGetValue("tag2", out tag2s) && tag2s.Count != 1 && !cts.IsCancellationRequested) + { + await Task.Yield(); + } + + Assert.Null(tag1s); + Assert.Single(tag2s); + } + + [Theory] + [InlineData(null)] + public async Task Store_Throws_OnInvalidTag(string tag) + { + var store = new MemoryOutputCacheStore(new MemoryCache(new MemoryCacheOptions())); + var value = "abc"u8.ToArray(); + var key = "abc"; + + await Assert.ThrowsAsync(async () => await store.SetAsync(key, value, new string[] { tag }, TimeSpan.FromMinutes(1), default)); + } + private class TestMemoryOptionsClock : Extensions.Internal.ISystemClock { public DateTimeOffset UtcNow { get; set; } diff --git a/src/Mvc/Mvc.Abstractions/src/ModelBinding/ModelMetadata.cs b/src/Mvc/Mvc.Abstractions/src/ModelBinding/ModelMetadata.cs index 90b32074102a..1f6e8042341a 100644 --- a/src/Mvc/Mvc.Abstractions/src/ModelBinding/ModelMetadata.cs +++ b/src/Mvc/Mvc.Abstractions/src/ModelBinding/ModelMetadata.cs @@ -545,6 +545,13 @@ internal void ThrowIfRecordTypeHasValidationOnProperties() internal static Func? FindTryParseMethod(Type modelType) { + if (modelType.IsByRef) + { + // ByRef is no supported in this case and + // will be reported later for the user. + return null; + } + modelType = Nullable.GetUnderlyingType(modelType) ?? modelType; return ParameterBindingMethodCache.FindTryParseMethod(modelType); } diff --git a/src/Mvc/Mvc.Core/test/ModelBinding/Metadata/DefaultModelMetadataTest.cs b/src/Mvc/Mvc.Core/test/ModelBinding/Metadata/DefaultModelMetadataTest.cs index 9eb9682e6b96..a5d639ef8c93 100644 --- a/src/Mvc/Mvc.Core/test/ModelBinding/Metadata/DefaultModelMetadataTest.cs +++ b/src/Mvc/Mvc.Core/test/ModelBinding/Metadata/DefaultModelMetadataTest.cs @@ -389,6 +389,27 @@ public void IsParseableType_ReturnsFalse_ForNonParseableTypes(Type modelType) Assert.False(isParseableType); } + [Theory] + [InlineData(typeof(string))] + [InlineData(typeof(int))] + public void IsParseableType_ReturnsFalse_ForByRefTypes(Type modelType) + { + // Arrange + var provider = new EmptyModelMetadataProvider(); + var detailsProvider = new EmptyCompositeMetadataDetailsProvider(); + + var key = ModelMetadataIdentity.ForType(modelType.MakeByRefType()); + var cache = new DefaultMetadataDetails(key, new ModelAttributes(Array.Empty(), null, null)); + + var metadata = new DefaultModelMetadata(provider, detailsProvider, cache); + + // Act + var isParseableType = metadata.IsParseableType; + + // Assert + Assert.False(isParseableType); + } + [Theory] [InlineData(typeof(string))] [InlineData(typeof(IDisposable))] diff --git a/src/ProjectTemplates/Web.ProjectTemplates/content/EmptyComponentsWebAssembly-CSharp/Client/MainLayout.razor b/src/ProjectTemplates/Web.ProjectTemplates/content/EmptyComponentsWebAssembly-CSharp/Client/MainLayout.razor index a5af3489ae82..de2be6c7fced 100644 --- a/src/ProjectTemplates/Web.ProjectTemplates/content/EmptyComponentsWebAssembly-CSharp/Client/MainLayout.razor +++ b/src/ProjectTemplates/Web.ProjectTemplates/content/EmptyComponentsWebAssembly-CSharp/Client/MainLayout.razor @@ -1,3 +1,5 @@ @inherits LayoutComponentBase -
@Body
+
+ @Body +
diff --git a/src/ProjectTemplates/Web.ProjectTemplates/content/EmptyComponentsWebAssembly-CSharp/Client/wwwroot/css/app.css b/src/ProjectTemplates/Web.ProjectTemplates/content/EmptyComponentsWebAssembly-CSharp/Client/wwwroot/css/app.css index 3afadc202ee5..ffcb043c0e50 100644 --- a/src/ProjectTemplates/Web.ProjectTemplates/content/EmptyComponentsWebAssembly-CSharp/Client/wwwroot/css/app.css +++ b/src/ProjectTemplates/Web.ProjectTemplates/content/EmptyComponentsWebAssembly-CSharp/Client/wwwroot/css/app.css @@ -1,3 +1,7 @@ +h1:focus { + outline: none; +} + #blazor-error-ui { background: lightyellow; bottom: 0; diff --git a/src/SignalR/clients/ts/FunctionalTests/ts/HubConnectionTests.ts b/src/SignalR/clients/ts/FunctionalTests/ts/HubConnectionTests.ts index 676aa698d192..2776e39895e9 100644 --- a/src/SignalR/clients/ts/FunctionalTests/ts/HubConnectionTests.ts +++ b/src/SignalR/clients/ts/FunctionalTests/ts/HubConnectionTests.ts @@ -376,7 +376,13 @@ describe("hubConnection", () => { try { await hubConnection.start(); } catch (error) { - expect(error!.message).toEqual(expectedErrorMessage); + if (error!.message.includes("404")) { + // SSE can race with the connection closing and the initial ping being successful or failing with a 404. + // LongPolling doesn't have pings and WebSockets is a synchronous API over a single HTTP request so it doesn't have the same issues + expect(error!.message).toEqual("No Connection with that ID: Status code '404'"); + } else { + expect(error!.message).toEqual(expectedErrorMessage); + } closePromise.resolve(); } await closePromise; diff --git a/src/SignalR/clients/ts/signalr/src/AccessTokenHttpClient.ts b/src/SignalR/clients/ts/signalr/src/AccessTokenHttpClient.ts new file mode 100644 index 000000000000..e28c98862af6 --- /dev/null +++ b/src/SignalR/clients/ts/signalr/src/AccessTokenHttpClient.ts @@ -0,0 +1,56 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +import { HeaderNames } from "./HeaderNames"; +import { HttpClient, HttpRequest, HttpResponse } from "./HttpClient"; + +/** @private */ +export class AccessTokenHttpClient extends HttpClient { + private _innerClient: HttpClient; + _accessToken: string | undefined; + _accessTokenFactory: (() => string | Promise) | undefined; + + constructor(innerClient: HttpClient, accessTokenFactory: (() => string | Promise) | undefined) { + super(); + + this._innerClient = innerClient; + this._accessTokenFactory = accessTokenFactory; + } + + public async send(request: HttpRequest): Promise { + let allowRetry = true; + if (this._accessTokenFactory && (!this._accessToken || (request.url && request.url.indexOf("/negotiate?") > 0))) { + // don't retry if the request is a negotiate or if we just got a potentially new token from the access token factory + allowRetry = false; + this._accessToken = await this._accessTokenFactory(); + } + this._setAuthorizationHeader(request); + const response = await this._innerClient.send(request); + + if (allowRetry && response.statusCode === 401 && this._accessTokenFactory) { + this._accessToken = await this._accessTokenFactory(); + this._setAuthorizationHeader(request); + return await this._innerClient.send(request); + } + return response; + } + + private _setAuthorizationHeader(request: HttpRequest) { + if (!request.headers) { + request.headers = {}; + } + if (this._accessToken) { + request.headers[HeaderNames.Authorization] = `Bearer ${this._accessToken}` + } + // don't remove the header if there isn't an access token factory, the user manually added the header in this case + else if (this._accessTokenFactory) { + if (request.headers[HeaderNames.Authorization]) { + delete request.headers[HeaderNames.Authorization]; + } + } + } + + public getCookieString(url: string): string { + return this._innerClient.getCookieString(url); + } +} \ No newline at end of file diff --git a/src/SignalR/clients/ts/signalr/src/HttpConnection.ts b/src/SignalR/clients/ts/signalr/src/HttpConnection.ts index 2cc7c14916f1..02e8d79aaed7 100644 --- a/src/SignalR/clients/ts/signalr/src/HttpConnection.ts +++ b/src/SignalR/clients/ts/signalr/src/HttpConnection.ts @@ -1,10 +1,9 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +import { AccessTokenHttpClient } from "./AccessTokenHttpClient"; import { DefaultHttpClient } from "./DefaultHttpClient"; import { AggregateErrors, DisabledTransportError, FailedToNegotiateWithServerError, FailedToStartTransportError, HttpError, UnsupportedTransportError, AbortError } from "./Errors"; -import { HeaderNames } from "./HeaderNames"; -import { HttpClient } from "./HttpClient"; import { IConnection } from "./IConnection"; import { IHttpConnectionOptions } from "./IHttpConnectionOptions"; import { ILogger, LogLevel } from "./ILogger"; @@ -47,7 +46,7 @@ export class HttpConnection implements IConnection { // connectionStarted is tracked independently from connectionState, so we can check if the // connection ever did successfully transition from connecting to connected before disconnecting. private _connectionStarted: boolean; - private readonly _httpClient: HttpClient; + private readonly _httpClient: AccessTokenHttpClient; private readonly _logger: ILogger; private readonly _options: IHttpConnectionOptions; // Needs to not start with _ to be available for tests @@ -110,7 +109,7 @@ export class HttpConnection implements IConnection { } } - this._httpClient = options.httpClient || new DefaultHttpClient(this._logger); + this._httpClient = new AccessTokenHttpClient(options.httpClient || new DefaultHttpClient(this._logger), options.accessTokenFactory); this._connectionState = ConnectionState.Disconnected; this._connectionStarted = false; this._options = options; @@ -227,6 +226,7 @@ export class HttpConnection implements IConnection { // as part of negotiating let url = this.baseUrl; this._accessTokenFactory = this._options.accessTokenFactory; + this._httpClient._accessTokenFactory = this._accessTokenFactory; try { if (this._options.skipNegotiation) { @@ -267,6 +267,9 @@ export class HttpConnection implements IConnection { // the returned access token const accessToken = negotiateResponse.accessToken; this._accessTokenFactory = () => accessToken; + // set the factory to undefined so the AccessTokenHttpClient won't retry with the same token, since we know it won't change until a connection restart + this._httpClient._accessToken = accessToken; + this._httpClient._accessTokenFactory = undefined; } redirects++; @@ -307,13 +310,6 @@ export class HttpConnection implements IConnection { private async _getNegotiationResponse(url: string): Promise { const headers: {[k: string]: string} = {}; - if (this._accessTokenFactory) { - const token = await this._accessTokenFactory(); - if (token) { - headers[HeaderNames.Authorization] = `Bearer ${token}`; - } - } - const [name, value] = getUserAgentHeader(); headers[name] = value; @@ -424,9 +420,9 @@ export class HttpConnection implements IConnection { if (!this._options.EventSource) { throw new Error("'EventSource' is not supported in your environment."); } - return new ServerSentEventsTransport(this._httpClient, this._accessTokenFactory, this._logger, this._options); + return new ServerSentEventsTransport(this._httpClient, this._httpClient._accessToken, this._logger, this._options); case HttpTransportType.LongPolling: - return new LongPollingTransport(this._httpClient, this._accessTokenFactory, this._logger, this._options); + return new LongPollingTransport(this._httpClient, this._logger, this._options); default: throw new Error(`Unknown transport: ${transport}.`); } diff --git a/src/SignalR/clients/ts/signalr/src/HubConnection.ts b/src/SignalR/clients/ts/signalr/src/HubConnection.ts index 67c32aef1fa1..3319af22f1b7 100644 --- a/src/SignalR/clients/ts/signalr/src/HubConnection.ts +++ b/src/SignalR/clients/ts/signalr/src/HubConnection.ts @@ -235,6 +235,10 @@ export class HubConnection { // eslint-disable-next-line @typescript-eslint/no-throw-literal throw this._stopDuringStartError; } + + if (!this.connection.features.inherentKeepAlive) { + await this._sendMessage(this._cachedPingMessage); + } } catch (e) { this._logger.log(LogLevel.Debug, `Hub handshake failed with error '${e}' during start(). Stopping HubConnection.`); diff --git a/src/SignalR/clients/ts/signalr/src/LongPollingTransport.ts b/src/SignalR/clients/ts/signalr/src/LongPollingTransport.ts index 6cd06b2a5b60..e4fbd8a36f18 100644 --- a/src/SignalR/clients/ts/signalr/src/LongPollingTransport.ts +++ b/src/SignalR/clients/ts/signalr/src/LongPollingTransport.ts @@ -3,7 +3,6 @@ import { AbortController } from "./AbortController"; import { HttpError, TimeoutError } from "./Errors"; -import { HeaderNames } from "./HeaderNames"; import { HttpClient, HttpRequest } from "./HttpClient"; import { ILogger, LogLevel } from "./ILogger"; import { ITransport, TransferFormat } from "./ITransport"; @@ -14,7 +13,6 @@ import { IHttpConnectionOptions } from "./IHttpConnectionOptions"; /** @private */ export class LongPollingTransport implements ITransport { private readonly _httpClient: HttpClient; - private readonly _accessTokenFactory: (() => string | Promise) | undefined; private readonly _logger: ILogger; private readonly _options: IHttpConnectionOptions; private readonly _pollAbort: AbortController; @@ -32,9 +30,8 @@ export class LongPollingTransport implements ITransport { return this._pollAbort.aborted; } - constructor(httpClient: HttpClient, accessTokenFactory: (() => string | Promise) | undefined, logger: ILogger, options: IHttpConnectionOptions) { + constructor(httpClient: HttpClient, logger: ILogger, options: IHttpConnectionOptions) { this._httpClient = httpClient; - this._accessTokenFactory = accessTokenFactory; this._logger = logger; this._pollAbort = new AbortController(); this._options = options; @@ -74,8 +71,6 @@ export class LongPollingTransport implements ITransport { pollOptions.responseType = "arraybuffer"; } - await this._updateHeaderToken(pollOptions); - // Make initial long polling request // Server uses first long polling request to finish initializing connection and it returns without data const pollUrl = `${url}&_=${Date.now()}`; @@ -94,28 +89,9 @@ export class LongPollingTransport implements ITransport { this._receiving = this._poll(this._url, pollOptions); } - private async _updateHeaderToken(request: HttpRequest): Promise { - if (!request.headers) { - request.headers = {}; - } - if (this._accessTokenFactory) { - const token = await this._accessTokenFactory(); - if (token) { - request.headers[HeaderNames.Authorization] = `Bearer ${token}` - } else { - if (request.headers[HeaderNames.Authorization]) { - delete request.headers[HeaderNames.Authorization]; - } - } - } - } - private async _poll(url: string, pollOptions: HttpRequest): Promise { try { while (this._running) { - // We have to get the access token on each poll, in case it changes - await this._updateHeaderToken(pollOptions); - try { const pollUrl = `${url}&_=${Date.now()}`; this._logger.log(LogLevel.Trace, `(LongPolling transport) polling: ${pollUrl}.`); @@ -174,7 +150,7 @@ export class LongPollingTransport implements ITransport { if (!this._running) { return Promise.reject(new Error("Cannot send until the transport is connected")); } - return sendMessage(this._logger, "LongPolling", this._httpClient, this._url!, this._accessTokenFactory, data, this._options); + return sendMessage(this._logger, "LongPolling", this._httpClient, this._url!, data, this._options); } public async stop(): Promise { @@ -199,7 +175,6 @@ export class LongPollingTransport implements ITransport { timeout: this._options.timeout, withCredentials: this._options.withCredentials, }; - await this._updateHeaderToken(deleteOptions); await this._httpClient.delete(this._url!, deleteOptions); this._logger.log(LogLevel.Trace, "(LongPolling transport) DELETE request sent."); diff --git a/src/SignalR/clients/ts/signalr/src/ServerSentEventsTransport.ts b/src/SignalR/clients/ts/signalr/src/ServerSentEventsTransport.ts index 7ac681992ac7..e147e163975d 100644 --- a/src/SignalR/clients/ts/signalr/src/ServerSentEventsTransport.ts +++ b/src/SignalR/clients/ts/signalr/src/ServerSentEventsTransport.ts @@ -11,7 +11,7 @@ import { IHttpConnectionOptions } from "./IHttpConnectionOptions"; /** @private */ export class ServerSentEventsTransport implements ITransport { private readonly _httpClient: HttpClient; - private readonly _accessTokenFactory: (() => string | Promise) | undefined; + private readonly _accessToken: string | undefined; private readonly _logger: ILogger; private readonly _options: IHttpConnectionOptions; private _eventSource?: EventSource; @@ -20,10 +20,10 @@ export class ServerSentEventsTransport implements ITransport { public onreceive: ((data: string | ArrayBuffer) => void) | null; public onclose: ((error?: Error) => void) | null; - constructor(httpClient: HttpClient, accessTokenFactory: (() => string | Promise) | undefined, logger: ILogger, + constructor(httpClient: HttpClient, accessToken: string | undefined, logger: ILogger, options: IHttpConnectionOptions) { this._httpClient = httpClient; - this._accessTokenFactory = accessTokenFactory; + this._accessToken = accessToken; this._logger = logger; this._options = options; @@ -38,14 +38,11 @@ export class ServerSentEventsTransport implements ITransport { this._logger.log(LogLevel.Trace, "(SSE transport) Connecting."); - // set url before accessTokenFactory because this.url is only for send and we set the auth header instead of the query string for send + // set url before accessTokenFactory because this._url is only for send and we set the auth header instead of the query string for send this._url = url; - if (this._accessTokenFactory) { - const token = await this._accessTokenFactory(); - if (token) { - url += (url.indexOf("?") < 0 ? "?" : "&") + `access_token=${encodeURIComponent(token)}`; - } + if (this._accessToken) { + url += (url.indexOf("?") < 0 ? "?" : "&") + `access_token=${encodeURIComponent(this._accessToken)}`; } return new Promise((resolve, reject) => { @@ -111,7 +108,7 @@ export class ServerSentEventsTransport implements ITransport { if (!this._eventSource) { return Promise.reject(new Error("Cannot send until the transport is connected")); } - return sendMessage(this._logger, "SSE", this._httpClient, this._url!, this._accessTokenFactory, data, this._options); + return sendMessage(this._logger, "SSE", this._httpClient, this._url!, data, this._options); } public stop(): Promise { diff --git a/src/SignalR/clients/ts/signalr/src/Utils.ts b/src/SignalR/clients/ts/signalr/src/Utils.ts index c5d19bfa9105..b60770b99a48 100644 --- a/src/SignalR/clients/ts/signalr/src/Utils.ts +++ b/src/SignalR/clients/ts/signalr/src/Utils.ts @@ -99,17 +99,9 @@ export function isArrayBuffer(val: any): val is ArrayBuffer { } /** @private */ -export async function sendMessage(logger: ILogger, transportName: string, httpClient: HttpClient, url: string, accessTokenFactory: (() => string | Promise) | undefined, +export async function sendMessage(logger: ILogger, transportName: string, httpClient: HttpClient, url: string, content: string | ArrayBuffer, options: IHttpConnectionOptions): Promise { - let headers: {[k: string]: string} = {}; - if (accessTokenFactory) { - const token = await accessTokenFactory(); - if (token) { - headers = { - ["Authorization"]: `Bearer ${token}`, - }; - } - } + const headers: {[k: string]: string} = {}; const [name, value] = getUserAgentHeader(); headers[name] = value; diff --git a/src/SignalR/clients/ts/signalr/tests/HttpConnection.test.ts b/src/SignalR/clients/ts/signalr/tests/HttpConnection.test.ts index b6dc4123f076..57a4aa6b9a0d 100644 --- a/src/SignalR/clients/ts/signalr/tests/HttpConnection.test.ts +++ b/src/SignalR/clients/ts/signalr/tests/HttpConnection.test.ts @@ -668,6 +668,7 @@ describe("HttpConnection", () => { await VerifyLogger.run(async (logger) => { let firstNegotiate = true; let firstPoll = true; + const pollingPromiseSource = new PromiseSource(); const httpClient = new TestHttpClient() .on("POST", /\/negotiate/, (r) => { if (firstNegotiate) { @@ -689,7 +690,7 @@ describe("HttpConnection", () => { connectionId: "0rge0d00-0040-0030-0r00-000q00r00e00", }; }) - .on("GET", (r) => { + .on("GET", async (r) => { if (r.headers && r.headers.Authorization !== "Bearer secondSecret") { return new HttpResponse(401, "Unauthorized", ""); } @@ -698,6 +699,7 @@ describe("HttpConnection", () => { firstPoll = false; return ""; } + await pollingPromiseSource.promise; return new HttpResponse(204, "No Content", ""); }) .on("DELETE", () => new HttpResponse(202)); @@ -719,6 +721,8 @@ describe("HttpConnection", () => { expect(httpClient.sentRequests[1].url).toBe("https://another.domain.url/chat/negotiate?negotiateVersion=1"); expect(httpClient.sentRequests[2].url).toMatch(/^https:\/\/another\.domain\.url\/chat\?id=0rge0d00-0040-0030-0r00-000q00r00e00/i); expect(httpClient.sentRequests[3].url).toMatch(/^https:\/\/another\.domain\.url\/chat\?id=0rge0d00-0040-0030-0r00-000q00r00e00/i); + + pollingPromiseSource.resolve(); } finally { await connection.stop(); } @@ -768,8 +772,11 @@ describe("HttpConnection", () => { httpClientGetCount++; const authorizationValue = r.headers![HeaderNames.Authorization]; if (httpClientGetCount === 1) { + // Auth failure to cause a retry with a call to accessTokenFactory + return new HttpResponse(401); + } else if (httpClientGetCount === 2) { if (authorizationValue) { - fail("First long poll request should have a authorization header."); + fail("First long poll request should have no authorization header."); } // First long polling request must succeed so start completes return ""; diff --git a/src/SignalR/clients/ts/signalr/tests/HubConnection.test.ts b/src/SignalR/clients/ts/signalr/tests/HubConnection.test.ts index 4b8614b5ed0c..630e3f3f0860 100644 --- a/src/SignalR/clients/ts/signalr/tests/HubConnection.test.ts +++ b/src/SignalR/clients/ts/signalr/tests/HubConnection.test.ts @@ -31,7 +31,7 @@ describe("HubConnection", () => { const hubConnection = createHubConnection(connection, logger); try { await hubConnection.start(); - expect(connection.sentData.length).toBe(1); + expect(connection.sentData.length).toBe(2); expect(JSON.parse(connection.sentData[0])).toEqual({ protocol: "json", version: 1, @@ -448,7 +448,7 @@ describe("HubConnection", () => { const subject = new Subject(); const invokePromise = hubConnection.invoke("testMethod", "arg", subject); - expect(JSON.parse(connection.sentData[1])).toEqual({ + expect(JSON.parse(connection.sentData[2])).toEqual({ arguments: ["arg"], invocationId: "1", streamIds: ["0"], @@ -460,7 +460,7 @@ describe("HubConnection", () => { await new Promise((resolve) => { setTimeout(resolve, 50); }); - expect(JSON.parse(connection.sentData[2])).toEqual({ + expect(JSON.parse(connection.sentData[3])).toEqual({ invocationId: "0", item: "item numero uno", type: MessageType.StreamItem, @@ -485,7 +485,7 @@ describe("HubConnection", () => { const subject = new Subject(); await hubConnection.send("testMethod", "arg", subject); - expect(JSON.parse(connection.sentData[1])).toEqual({ + expect(JSON.parse(connection.sentData[2])).toEqual({ arguments: ["arg"], streamIds: ["0"], target: "testMethod", @@ -496,7 +496,7 @@ describe("HubConnection", () => { await new Promise((resolve) => { setTimeout(resolve, 50); }); - expect(JSON.parse(connection.sentData[2])).toEqual({ + expect(JSON.parse(connection.sentData[3])).toEqual({ invocationId: "0", item: "item numero uno", type: MessageType.StreamItem, @@ -528,7 +528,7 @@ describe("HubConnection", () => { }, }); - expect(JSON.parse(connection.sentData[1])).toEqual({ + expect(JSON.parse(connection.sentData[2])).toEqual({ arguments: ["arg"], invocationId: "1", streamIds: ["0"], @@ -540,7 +540,7 @@ describe("HubConnection", () => { await new Promise((resolve) => { setTimeout(resolve, 50); }); - expect(JSON.parse(connection.sentData[2])).toEqual({ + expect(JSON.parse(connection.sentData[3])).toEqual({ invocationId: "0", item: "item numero uno", type: MessageType.StreamItem, @@ -1102,10 +1102,10 @@ describe("HubConnection", () => { // async here to guarantee the sent message is written await delayUntil(1); - expect(connection.parsedSentData.length).toEqual(2); - expect(connection.parsedSentData[1].type).toEqual(3); - expect(connection.parsedSentData[1].result).toEqual(10); - expect(connection.parsedSentData[1].invocationId).toEqual("1"); + expect(connection.parsedSentData.length).toEqual(3); + expect(connection.parsedSentData[2].type).toEqual(3); + expect(connection.parsedSentData[2].result).toEqual(10); + expect(connection.parsedSentData[2].invocationId).toEqual("1"); } finally { await hubConnection.stop(); } @@ -1133,10 +1133,10 @@ describe("HubConnection", () => { // async here to guarantee the sent message is written await delayUntil(1); - expect(connection.parsedSentData.length).toEqual(2); - expect(connection.parsedSentData[1].type).toEqual(3); - expect(connection.parsedSentData[1].result).toBeNull(); - expect(connection.parsedSentData[1].invocationId).toEqual("1"); + expect(connection.parsedSentData.length).toEqual(3); + expect(connection.parsedSentData[2].type).toEqual(3); + expect(connection.parsedSentData[2].result).toBeNull(); + expect(connection.parsedSentData[2].invocationId).toEqual("1"); } finally { await hubConnection.stop(); } @@ -1166,10 +1166,10 @@ describe("HubConnection", () => { // async here to guarantee the sent message is written await delayUntil(1); - expect(connection.parsedSentData.length).toEqual(2); - expect(connection.parsedSentData[1].type).toEqual(3); - expect(connection.parsedSentData[1].result).toEqual(13); - expect(connection.parsedSentData[1].invocationId).toEqual("1"); + expect(connection.parsedSentData.length).toEqual(3); + expect(connection.parsedSentData[2].type).toEqual(3); + expect(connection.parsedSentData[2].result).toEqual(13); + expect(connection.parsedSentData[2].invocationId).toEqual("1"); } finally { await hubConnection.stop(); } @@ -1197,10 +1197,10 @@ describe("HubConnection", () => { // async here to guarantee the sent message is written await delayUntil(1); - expect(connection.parsedSentData.length).toEqual(2); - expect(connection.parsedSentData[1].type).toEqual(3); - expect(connection.parsedSentData[1].error).toEqual("Error: from callback"); - expect(connection.parsedSentData[1].invocationId).toEqual("1"); + expect(connection.parsedSentData.length).toEqual(3); + expect(connection.parsedSentData[2].type).toEqual(3); + expect(connection.parsedSentData[2].error).toEqual("Error: from callback"); + expect(connection.parsedSentData[2].invocationId).toEqual("1"); } finally { await hubConnection.stop(); } @@ -1229,10 +1229,10 @@ describe("HubConnection", () => { // async here to guarantee the sent message is written await delayUntil(1); - expect(connection.parsedSentData.length).toEqual(2); - expect(connection.parsedSentData[1].type).toEqual(3); - expect(connection.parsedSentData[1].error).toEqual('Client provided multiple results.'); - expect(connection.parsedSentData[1].invocationId).toEqual("1"); + expect(connection.parsedSentData.length).toEqual(3); + expect(connection.parsedSentData[2].type).toEqual(3); + expect(connection.parsedSentData[2].error).toEqual('Client provided multiple results.'); + expect(connection.parsedSentData[2].invocationId).toEqual("1"); } finally { await hubConnection.stop(); } @@ -1261,11 +1261,11 @@ describe("HubConnection", () => { // async here to guarantee the sent message is written await delayUntil(1); - expect(connection.parsedSentData.length).toEqual(2); - expect(connection.parsedSentData[1].type).toEqual(3); - expect(connection.parsedSentData[1].error).toEqual("Error: from callback"); - expect(connection.parsedSentData[1].result).toBeUndefined(); - expect(connection.parsedSentData[1].invocationId).toEqual("1"); + expect(connection.parsedSentData.length).toEqual(3); + expect(connection.parsedSentData[2].type).toEqual(3); + expect(connection.parsedSentData[2].error).toEqual("Error: from callback"); + expect(connection.parsedSentData[2].result).toBeUndefined(); + expect(connection.parsedSentData[2].invocationId).toEqual("1"); } finally { await hubConnection.stop(); } @@ -1294,11 +1294,11 @@ describe("HubConnection", () => { // async here to guarantee the sent message is written await delayUntil(1); - expect(connection.parsedSentData.length).toEqual(2); - expect(connection.parsedSentData[1].type).toEqual(3); - expect(connection.parsedSentData[1].result).toEqual(3); - expect(connection.parsedSentData[1].error).toBeUndefined(); - expect(connection.parsedSentData[1].invocationId).toEqual("1"); + expect(connection.parsedSentData.length).toEqual(3); + expect(connection.parsedSentData[2].type).toEqual(3); + expect(connection.parsedSentData[2].result).toEqual(3); + expect(connection.parsedSentData[2].error).toBeUndefined(); + expect(connection.parsedSentData[2].invocationId).toEqual("1"); } finally { await hubConnection.stop(); } @@ -1326,10 +1326,10 @@ describe("HubConnection", () => { // async here to guarantee the sent message is written await delayUntil(1); - expect(connection.parsedSentData.length).toEqual(2); - expect(connection.parsedSentData[1].type).toEqual(3); - expect(connection.parsedSentData[1].error).toEqual("Client didn't provide a result."); - expect(connection.parsedSentData[1].invocationId).toEqual("1"); + expect(connection.parsedSentData.length).toEqual(3); + expect(connection.parsedSentData[2].type).toEqual(3); + expect(connection.parsedSentData[2].error).toEqual("Client didn't provide a result."); + expect(connection.parsedSentData[2].invocationId).toEqual("1"); } finally { await hubConnection.stop(); } @@ -1355,10 +1355,10 @@ describe("HubConnection", () => { // async here to guarantee the sent message is written await delayUntil(1); - expect(connection.parsedSentData.length).toEqual(2); - expect(connection.parsedSentData[1].type).toEqual(3); - expect(connection.parsedSentData[1].error).toEqual("Client didn't provide a result."); - expect(connection.parsedSentData[1].invocationId).toEqual("1"); + expect(connection.parsedSentData.length).toEqual(3); + expect(connection.parsedSentData[2].type).toEqual(3); + expect(connection.parsedSentData[2].error).toEqual("Client didn't provide a result."); + expect(connection.parsedSentData[2].invocationId).toEqual("1"); } finally { await hubConnection.stop(); } @@ -1386,7 +1386,7 @@ describe("HubConnection", () => { // async here to guarantee the sent message is written await delayUntil(1); - expect(connection.parsedSentData.length).toEqual(1); + expect(connection.parsedSentData.length).toEqual(2); } finally { await hubConnection.stop(); } @@ -1405,9 +1405,9 @@ describe("HubConnection", () => { hubConnection.stream("testStream", "arg", 42); - // Verify the message is sent (+ handshake) - expect(connection.sentData.length).toBe(2); - expect(JSON.parse(connection.sentData[1])).toEqual({ + // Verify the message is sent (+ handshake + ping) + expect(connection.sentData.length).toBe(3); + expect(JSON.parse(connection.sentData[2])).toEqual({ arguments: [ "arg", 42, @@ -1416,9 +1416,6 @@ describe("HubConnection", () => { target: "testStream", type: MessageType.StreamInvocation, }); - - // Close the connection - await hubConnection.stop(); } finally { await hubConnection.stop(); } @@ -1592,10 +1589,10 @@ describe("HubConnection", () => { expect(observer.itemsReceived).toEqual([1]); // Close message sent asynchronously so we need to wait - await delayUntil(1000, () => connection.sentData.length === 3); + await delayUntil(1000, () => connection.sentData.length === 4); // Verify the cancel is sent (+ handshake) - expect(connection.sentData.length).toBe(3); - expect(JSON.parse(connection.sentData[2])).toEqual({ + expect(connection.sentData.length).toBe(4); + expect(JSON.parse(connection.sentData[3])).toEqual({ invocationId: connection.lastInvocationId, type: MessageType.CancelInvocation, }); @@ -1830,7 +1827,9 @@ class TestProtocol implements IHubProtocol { } public writeMessage(message: HubMessage): any { - + if (message.type === 6) { + return "{\"type\": 6}" + TextMessageFormat.RecordSeparator; + } } } diff --git a/src/SignalR/clients/ts/signalr/tests/LongPollingTransport.test.ts b/src/SignalR/clients/ts/signalr/tests/LongPollingTransport.test.ts index 37ea48b8168e..1d985fa7f7e2 100644 --- a/src/SignalR/clients/ts/signalr/tests/LongPollingTransport.test.ts +++ b/src/SignalR/clients/ts/signalr/tests/LongPollingTransport.test.ts @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +import { AccessTokenHttpClient } from "../src/AccessTokenHttpClient"; import { HttpResponse } from "../src/HttpClient"; import { TransferFormat } from "../src/ITransport"; import { LongPollingTransport } from "../src/LongPollingTransport"; @@ -40,7 +41,7 @@ describe("LongPollingTransport", () => { } }) .on("DELETE", () => new HttpResponse(202)); - const transport = new LongPollingTransport(client, undefined, logger, { logMessageContent: false, withCredentials: true, headers: {} }); + const transport = new LongPollingTransport(client, logger, { logMessageContent: false, withCredentials: true, headers: {} }); await transport.connect("http://example.com", TransferFormat.Text); const stopPromise = transport.stop(); @@ -64,7 +65,7 @@ describe("LongPollingTransport", () => { return new HttpResponse(204); } }); - const transport = new LongPollingTransport(client, undefined, logger, { logMessageContent: false, withCredentials: true, headers: {} }); + const transport = new LongPollingTransport(client, logger, { logMessageContent: false, withCredentials: true, headers: {} }); const stopPromise = makeClosedPromise(transport); @@ -97,7 +98,7 @@ describe("LongPollingTransport", () => { return new HttpResponse(202); }); - const transport = new LongPollingTransport(httpClient, undefined, logger, { logMessageContent: false, withCredentials: true, headers: {} }); + const transport = new LongPollingTransport(httpClient, logger, { logMessageContent: false, withCredentials: true, headers: {} }); await transport.connect("http://tempuri.org", TransferFormat.Text); @@ -146,7 +147,7 @@ describe("LongPollingTransport", () => { return new HttpResponse(202); }); - const transport = new LongPollingTransport(httpClient, undefined, logger, { logMessageContent: false, withCredentials: true, headers: {} }); + const transport = new LongPollingTransport(httpClient, logger, { logMessageContent: false, withCredentials: true, headers: {} }); await transport.connect("http://tempuri.org", TransferFormat.Text); @@ -203,7 +204,7 @@ describe("LongPollingTransport", () => { return new HttpResponse(202); }); - const transport = new LongPollingTransport(httpClient, undefined, logger, { logMessageContent: false, withCredentials: true, headers }); + const transport = new LongPollingTransport(httpClient, logger, { logMessageContent: false, withCredentials: true, headers }); await transport.connect("http://tempuri.org", TransferFormat.Text); @@ -251,7 +252,7 @@ describe("LongPollingTransport", () => { return new HttpResponse(202); }); - const transport = new LongPollingTransport(httpClient, undefined, logger, + const transport = new LongPollingTransport(httpClient, logger, { logMessageContent: false, withCredentials: true, headers: {}, timeout: 123 }); await transport.connect("http://tempuri.org", TransferFormat.Text); @@ -272,17 +273,31 @@ describe("LongPollingTransport", () => { it("removes Authorization header when accessTokenFactory returns empty", async () => { await VerifyLogger.run(async (logger) => { let firstPoll = true; + let secondPoll = false; let firstAuthHeader = ""; let secondAuthHeader = ""; let deleteAuthHeader = ""; + const accessTokenFactory = () => { + if (firstAuthHeader) { + return ""; + } + return "value"; + }; const pollingPromiseSource = new PromiseSource(); - const httpClient = new TestHttpClient() + const readyToStopPromiseSource = new PromiseSource(); + const httpClient = new AccessTokenHttpClient(new TestHttpClient() .on("GET", async (r) => { if (firstPoll) { firstPoll = false; + secondPoll = true; firstAuthHeader = r.headers!.Authorization; return new HttpResponse(200); + } else if (secondPoll) { + secondPoll = false; + // force a retry so the access token factory is called again + return new HttpResponse(401); } else { + readyToStopPromiseSource.resolve(); secondAuthHeader = r.headers!.Authorization; await pollingPromiseSource.promise; return new HttpResponse(204); @@ -291,17 +306,14 @@ describe("LongPollingTransport", () => { .on("DELETE", async (r) => { deleteAuthHeader = r.headers!.Authorization; return new HttpResponse(202); - }); + }), accessTokenFactory); - const transport = new LongPollingTransport(httpClient, () => { - if (firstAuthHeader) { - return ""; - } - return "value"; - }, logger, { logMessageContent: false, withCredentials: true, headers: {} }); + const transport = new LongPollingTransport(httpClient, logger, { logMessageContent: false, withCredentials: true, headers: {} }); await transport.connect("http://tempuri.org", TransferFormat.Text); + await readyToStopPromiseSource.promise; + // Begin stopping transport const stopPromise = transport.stop(); @@ -324,7 +336,7 @@ describe("LongPollingTransport", () => { let secondAuthHeader = ""; let deleteAuthHeader = ""; const pollingPromiseSource = new PromiseSource(); - const httpClient = new TestHttpClient() + const httpClient = new AccessTokenHttpClient(new TestHttpClient() .on("GET", async (r) => { if (firstPoll) { firstPoll = false; @@ -339,9 +351,9 @@ describe("LongPollingTransport", () => { .on("DELETE", async (r) => { deleteAuthHeader = r.headers!.Authorization; return new HttpResponse(202); - }); + }), undefined); - const transport = new LongPollingTransport(httpClient, undefined, logger, { logMessageContent: false, withCredentials: true, + const transport = new LongPollingTransport(httpClient, logger, { logMessageContent: false, withCredentials: true, headers: { Authorization: "Basic test" } }); await transport.connect("http://tempuri.org", TransferFormat.Text); diff --git a/src/SignalR/clients/ts/signalr/tests/ServerSentEventsTransport.test.ts b/src/SignalR/clients/ts/signalr/tests/ServerSentEventsTransport.test.ts index 424939f6afeb..7a9f2c0f9082 100644 --- a/src/SignalR/clients/ts/signalr/tests/ServerSentEventsTransport.test.ts +++ b/src/SignalR/clients/ts/signalr/tests/ServerSentEventsTransport.test.ts @@ -4,7 +4,7 @@ import { MessageHeaders } from "../src/IHubProtocol"; import { TransferFormat } from "../src/ITransport"; -import { HttpClient, HttpRequest } from "../src/HttpClient"; +import { HttpClient, HttpRequest, HttpResponse } from "../src/HttpClient"; import { ILogger } from "../src/ILogger"; import { ServerSentEventsTransport } from "../src/ServerSentEventsTransport"; import { getUserAgentHeader } from "../src/Utils"; @@ -13,6 +13,7 @@ import { TestEventSource, TestMessageEvent } from "./TestEventSource"; import { TestHttpClient } from "./TestHttpClient"; import { registerUnhandledRejectionHandler } from "./Utils"; import { IHttpConnectionOptions } from "signalr/src/IHttpConnectionOptions"; +import { AccessTokenHttpClient } from "../src/AccessTokenHttpClient"; registerUnhandledRejectionHandler(); @@ -87,10 +88,10 @@ describe("ServerSentEventsTransport", () => { it("sets Authorization header on sends", async () => { await VerifyLogger.run(async (logger) => { let request: HttpRequest; - const httpClient = new TestHttpClient().on((r) => { + const httpClient = new AccessTokenHttpClient(new TestHttpClient().on((r) => { request = r; return ""; - }); + }), () => "secretToken"); const sse = await createAndStartSSE(logger, "http://example.com", () => "secretToken", { httpClient }); @@ -101,6 +102,35 @@ describe("ServerSentEventsTransport", () => { }); }); + it("retries 401 requests on sends", async () => { + await VerifyLogger.run(async (logger) => { + let request: HttpRequest; + let requestCount = 0; + const httpClient = new AccessTokenHttpClient(new TestHttpClient().on((r) => { + requestCount++; + if (requestCount === 2) { + return new HttpResponse(401); + } + request = r; + return ""; + }), () => "secretToken" + requestCount); + + // AccessTokenHttpClient assumes negotiate was called which would have called accessTokenFactory already + // It also assumes the request shouldn't be retried if the factory was called, so we need to make a "negotiate" call + // to test the retry behavior for send requests + await httpClient.post(""); + expect(request!.headers!.Authorization).toBe("Bearer secretToken0"); + + const sse = await createAndStartSSE(logger, "http://example.com", () => "secretToken", { httpClient }); + + await sse.send(""); + + expect(request!.headers!.Authorization).toBe("Bearer secretToken2"); + expect(request!.url).toBe("http://example.com"); + expect(requestCount).toEqual(3); + }); + }); + it("can send data", async () => { await VerifyLogger.run(async (logger) => { let request: HttpRequest; @@ -267,7 +297,12 @@ describe("ServerSentEventsTransport", () => { }); async function createAndStartSSE(logger: ILogger, url?: string, accessTokenFactory?: (() => string | Promise), options?: IHttpConnectionOptions): Promise { - const sse = new ServerSentEventsTransport(options?.httpClient || new TestHttpClient(), accessTokenFactory, logger, + let token; + // SSE assumes (correctly) that negotiate will already create the token we want to use on connection startup, so simulate that here when creating the SSE transport + if (accessTokenFactory) { + token = await accessTokenFactory(); + } + const sse = new ServerSentEventsTransport(options?.httpClient || new TestHttpClient(), token, logger, { logMessageContent: true, EventSource: TestEventSource, withCredentials: true, timeout: 10205, ...options }); const connectPromise = sse.connect(url || "http://example.com", TransferFormat.Text); diff --git a/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs b/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs index 6d726e0c2c17..220aae9d7005 100644 --- a/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs +++ b/src/SignalR/common/Protocols.Json/src/Protocol/JsonHubProtocol.cs @@ -238,7 +238,15 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) } else { - result = BindType(ref reader, input, returnType); + try + { + result = BindType(ref reader, input, returnType); + } + catch (Exception ex) + { + error = $"Error trying to deserialize result to {returnType.Name}. {ex.Message}"; + hasResult = false; + } } } } @@ -423,7 +431,15 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) } else { - result = BindType(ref resultToken, input, returnType); + try + { + result = BindType(ref resultToken, input, returnType); + } + catch (Exception ex) + { + error = $"Error trying to deserialize result to {returnType.Name}. {ex.Message}"; + hasResult = false; + } } } diff --git a/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocolWorker.cs b/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocolWorker.cs index df0dc0c7a8a9..50fb12bf9e24 100644 --- a/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocolWorker.cs +++ b/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocolWorker.cs @@ -162,6 +162,7 @@ private CompletionMessage CreateCompletionMessage(ref MessagePackReader reader, error = ReadString(ref reader, "error"); break; case NonVoidResult: + hasResult = true; var itemType = ProtocolHelper.TryGetReturnType(binder, invocationId); if (itemType is null) { @@ -175,10 +176,17 @@ private CompletionMessage CreateCompletionMessage(ref MessagePackReader reader, } else { - result = DeserializeObject(ref reader, itemType, "argument"); + try + { + result = DeserializeObject(ref reader, itemType, "argument"); + } + catch (Exception ex) + { + error = $"Error trying to deserialize result to {itemType.Name}. {ex.Message}"; + hasResult = false; + } } } - hasResult = true; break; case VoidResult: hasResult = false; diff --git a/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs b/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs index 11dd9d107adb..b0083731f976 100644 --- a/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs +++ b/src/SignalR/common/Protocols.NewtonsoftJson/src/Protocol/NewtonsoftJsonHubProtocol.cs @@ -229,7 +229,15 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) } else { - result = PayloadSerializer.Deserialize(reader, returnType); + try + { + result = PayloadSerializer.Deserialize(reader, returnType); + } + catch (Exception ex) + { + error = $"Error trying to deserialize result to {returnType.Name}. {ex.Message}"; + hasResult = false; + } } } } @@ -417,7 +425,15 @@ public ReadOnlyMemory GetMessageBytes(HubMessage message) } else { - result = resultToken.ToObject(returnType, PayloadSerializer); + try + { + result = resultToken.ToObject(returnType, PayloadSerializer); + } + catch (Exception ex) + { + error = $"Error trying to deserialize result to {returnType.Name}. {ex.Message}"; + hasResult = false; + } } } } diff --git a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/JsonHubProtocolTestsBase.cs b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/JsonHubProtocolTestsBase.cs index 59e765acf0c4..b5ad57c61eb0 100644 --- a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/JsonHubProtocolTestsBase.cs +++ b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/JsonHubProtocolTestsBase.cs @@ -455,16 +455,35 @@ public void RawResultRoundTripsProperly(string testDataName) } } - [Fact] - public void UnexpectedClientResultGivesEmptyCompletionMessage() + [Theory] + [InlineData("{\"type\":3,\"result\":1,\"invocationId\":\"1\"}")] + [InlineData("{\"result\":1,\"type\":3,\"invocationId\":\"1\"}")] + public void UnexpectedClientResultGivesEmptyCompletionMessage(string input) { var binder = new TestBinder(); - var message = Frame("{\"type\":3,\"result\":1,\"invocationId\":\"1\"}"); + var message = Frame(input); + var data = new ReadOnlySequence(Encoding.UTF8.GetBytes(message)); + Assert.True(JsonHubProtocol.TryParseMessage(ref data, binder, out var hubMessage)); + + var completion = Assert.IsType(hubMessage); + Assert.Null(completion.Result); + Assert.Null(completion.Error); + Assert.Equal("1", completion.InvocationId); + } + + [Theory] + [InlineData("{\"type\":3,\"result\":\"string\",\"invocationId\":\"1\"}")] + [InlineData("{\"result\":\"string\",\"type\":3,\"invocationId\":\"1\"}")] + public void WrongTypeForClientResultGivesErrorCompletionMessage(string input) + { + var binder = new TestBinder(paramTypes: null, returnType: typeof(int)); + var message = Frame(input); var data = new ReadOnlySequence(Encoding.UTF8.GetBytes(message)); Assert.True(JsonHubProtocol.TryParseMessage(ref data, binder, out var hubMessage)); var completion = Assert.IsType(hubMessage); Assert.Null(completion.Result); + Assert.StartsWith("Error trying to deserialize result to Int32.", completion.Error); Assert.Equal("1", completion.InvocationId); } diff --git a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTestBase.cs b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTestBase.cs index 25d936152d05..6ef8a626e13f 100644 --- a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTestBase.cs +++ b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTestBase.cs @@ -307,8 +307,10 @@ protected void TestWriteMessages(ProtocolTestData testData) new InvalidMessageData("CompletionResultKindOutOfRange", new byte[] { 0x94, 3, 0x80, 0xa3, (byte)'x', (byte)'y', (byte)'z', 42 }, "Invalid invocation result kind."), new InvalidMessageData("CompletionErrorMissing", new byte[] { 0x94, 3, 0x80, 0xa3, (byte)'x', (byte)'y', (byte)'z', 0x01 }, "Reading 'error' as String failed."), new InvalidMessageData("CompletionErrorInt", new byte[] { 0x95, 3, 0x80, 0xa3, (byte)'x', (byte)'y', (byte)'z', 0x01, 42 }, "Reading 'error' as String failed."), - new InvalidMessageData("CompletionResultMissing", new byte[] { 0x94, 3, 0x80, 0xa3, (byte)'x', (byte)'y', (byte)'z', 0x03 }, "Deserializing object of the `String` type for 'argument' failed."), - new InvalidMessageData("CompletionResultTypeMismatch", new byte[] { 0x95, 3, 0x80, 0xa3, (byte)'x', (byte)'y', (byte)'z', 0x03, 42 }, "Deserializing object of the `String` type for 'argument' failed."), + + // These now result in CompletionMessages with the error field set + //new InvalidMessageData("CompletionResultMissing", new byte[] { 0x94, 3, 0x80, 0xa3, (byte)'x', (byte)'y', (byte)'z', 0x03 }, "Deserializing object of the `String` type for 'argument' failed."), + //new InvalidMessageData("CompletionResultTypeMismatch", new byte[] { 0x95, 3, 0x80, 0xa3, (byte)'x', (byte)'y', (byte)'z', 0x03, 42 }, "Deserializing object of the `String` type for 'argument' failed."), }.ToDictionary(t => t.Name); public static IEnumerable BaseInvalidPayloadNames => BaseInvalidPayloads.Keys.Select(name => new object[] { name }); diff --git a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTests.cs b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTests.cs index 647f50cd3324..29fe23facd51 100644 --- a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTests.cs +++ b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTests.cs @@ -256,6 +256,20 @@ public void UnexpectedClientResultGivesEmptyCompletionMessage() Assert.Equal("xyz", completion.InvocationId); } + [Fact] + public void WrongTypeForClientResultGivesErrorCompletionMessage() + { + var binder = new TestBinder(paramTypes: null, returnType: typeof(int)); + var input = Frame(Convert.FromBase64String("lQOAo3h5egOmc3RyaW5n")); + var data = new ReadOnlySequence(input); + Assert.True(HubProtocol.TryParseMessage(ref data, binder, out var hubMessage)); + + var completion = Assert.IsType(hubMessage); + Assert.Null(completion.Result); + Assert.StartsWith("Error trying to deserialize result to Int32.", completion.Error); + Assert.Equal("xyz", completion.InvocationId); + } + public class ClientResultTestData { public string Name { get; } diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index 30c06e594c7c..3458f6760a9f 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -319,6 +319,15 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, ReplaceArguments(descriptor, hubMethodInvocationMessage, isStreamCall, connection, scope, ref arguments, out cts); } + if (isStreamCall || isStreamResponse) + { + Debug.Assert(hub.Clients is HubCallerClients); + // Streaming invocations aren't involved with the semaphore. + // Setting the semaphore released flag avoids potential client result calls from the streaming hub method + // releasing the semaphore which would cause a SemaphoreFullException. + ((HubCallerClients)hub.Clients).TrySetSemaphoreReleased(); + } + if (isStreamResponse) { _ = StreamAsync(hubMethodInvocationMessage.InvocationId!, connection, arguments, scope, hubActivator, hub, cts, hubMethodInvocationMessage, descriptor); @@ -404,7 +413,7 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, { if (hub?.Clients is HubCallerClients hubCallerClients) { - wasSemaphoreReleased = Interlocked.CompareExchange(ref hubCallerClients.ShouldReleaseSemaphore, 0, 1) == 0; + wasSemaphoreReleased = !hubCallerClients.TrySetSemaphoreReleased(); } await CleanupInvocation(connection, hubMethodInvocationMessage, hubActivator, hub, scope); } diff --git a/src/SignalR/server/Core/src/Internal/HubCallerClients.cs b/src/SignalR/server/Core/src/Internal/HubCallerClients.cs index 8e6ec0fa0dc9..15d2c036602e 100644 --- a/src/SignalR/server/Core/src/Internal/HubCallerClients.cs +++ b/src/SignalR/server/Core/src/Internal/HubCallerClients.cs @@ -9,7 +9,7 @@ internal sealed class HubCallerClients : IHubCallerClients private readonly IHubClients _hubClients; internal readonly ChannelBasedSemaphore _parallelInvokes; - internal int ShouldReleaseSemaphore = 1; + private int _shouldReleaseSemaphore = 1; // Client results don't work in OnConnectedAsync // This property is set by the hub dispatcher when those methods are being called @@ -90,6 +90,12 @@ public IClientProxy Users(IReadOnlyList userIds) return _hubClients.Users(userIds); } + // false if semaphore is being released by another caller, true if you own releasing the semaphore + internal bool TrySetSemaphoreReleased() + { + return Interlocked.CompareExchange(ref _shouldReleaseSemaphore, 0, 1) == 1; + } + private sealed class NoInvokeSingleClientProxy : ISingleClientProxy { private readonly ISingleClientProxy _proxy; @@ -125,9 +131,9 @@ public async Task InvokeCoreAsync(string method, object?[] args, Cancellat { // Releases the Channel that is blocking pending invokes, which in turn can block the receive loop. // Because we are waiting for a result from the client we need to let the receive loop run otherwise we'll be blocked forever - var value = Interlocked.CompareExchange(ref _hubCallerClients.ShouldReleaseSemaphore, 0, 1); + var value = _hubCallerClients.TrySetSemaphoreReleased(); // Only release once, and we set ShouldReleaseSemaphore to 0 so the DefaultHubDispatcher knows not to call Release again - if (value == 1) + if (value) { _hubCallerClients._parallelInvokes.Release(); } diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs index dc4ad919292e..b49e22583229 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs @@ -356,6 +356,18 @@ public void BackgroundClientResult(TcsService tcsService) } }); } + + public async Task GetClientResultWithStream(ChannelReader channelReader) + { + var sum = await Clients.Caller.InvokeAsync("Sum", 1, cancellationToken: default); + return sum; + } + + public async IAsyncEnumerable StreamWithClientResult() + { + var sum = await Clients.Caller.InvokeAsync("Sum", 1, cancellationToken: default); + yield return sum; + } } internal class SelfRef diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs index 1320c674d622..3b30eb6f6b10 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs @@ -4,6 +4,7 @@ using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.AspNetCore.Testing; using Microsoft.Extensions.DependencyInjection; +using Moq; namespace Microsoft.AspNetCore.SignalR.Tests; @@ -440,6 +441,75 @@ public async Task CanCancelClientResultsWithIHubContext() } } + [Fact] + public async Task ClientResultInUploadStreamingMethodWorks() + { + using (StartVerifiableLog()) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => { }, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + + var invocationId = await client.BeginUploadStreamAsync("1", nameof(MethodHub.GetClientResultWithStream), new[] { "id" }, Array.Empty()).DefaultTimeout(); + + // Hub asks client for a result, this is an invocation message with an ID + var invocationMessage = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + Assert.NotNull(invocationMessage.InvocationId); + var res = 4 + ((long)invocationMessage.Arguments[0]); + await client.SendHubMessageAsync(CompletionMessage.WithResult(invocationMessage.InvocationId, res)).DefaultTimeout(); + + var completion = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + Assert.Equal(5L, completion.Result); + Assert.Equal(invocationId, completion.InvocationId); + + // Make sure we can still do a Hub invocation and that the semaphore state didn't get messed up + var completionMessage = await client.InvokeAsync(nameof(MethodHub.ValueMethod)).DefaultTimeout(); + Assert.Equal(43L, completionMessage.Result); + } + } + } + + [Fact] + public async Task ClientResultInStreamingMethodWorks() + { + using (StartVerifiableLog()) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => { }, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + var invocationBinder = new Mock(); + invocationBinder.Setup(b => b.GetStreamItemType(It.IsAny())).Returns(typeof(int)); + invocationBinder.Setup(b => b.GetParameterTypes(It.IsAny())).Returns(new[] { typeof(int) }); + invocationBinder.Setup(b => b.GetReturnType(It.IsAny())).Returns(typeof(int)); + using (var client = new TestClient(invocationBinder: invocationBinder.Object)) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + + var invocationId = await client.SendStreamInvocationAsync(nameof(MethodHub.StreamWithClientResult)).DefaultTimeout(); + + // Hub asks client for a result, this is an invocation message with an ID + var invocationMessage = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + Assert.NotNull(invocationMessage.InvocationId); + var res = 4 + ((int)invocationMessage.Arguments[0]); + await client.SendHubMessageAsync(CompletionMessage.WithResult(invocationMessage.InvocationId, res)).DefaultTimeout(); + + var streamItem = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + Assert.Equal(5, streamItem.Item); + Assert.Equal(invocationId, streamItem.InvocationId); + + var completionMessage = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + Assert.Equal(invocationId, completionMessage.InvocationId); + + // Make sure we can still do a Hub invocation and that the semaphore state didn't get messed up + completionMessage = await client.InvokeAsync(nameof(MethodHub.ValueMethod)).DefaultTimeout(); + Assert.Equal(43, completionMessage.Result); + } + } + } + private class GetClientResultTwoWaysInvocationBinder : IInvocationBinder { public IReadOnlyList GetParameterTypes(string methodName) => new[] { typeof(int) }; diff --git a/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs b/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs index 3e26bc244f58..610a89be138f 100644 --- a/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs +++ b/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs @@ -626,4 +626,35 @@ public async Task ConnectionDoesNotExist_FailsInvokeConnectionAsync() var ex = await Assert.ThrowsAsync(() => manager1.InvokeConnectionAsync("1234", "Result", new object[] { "test" }, cancellationToken: default)).DefaultTimeout(); Assert.Equal("Connection '1234' does not exist.", ex.Message); } + + /// + /// Specification test for SignalR HubLifetimeManager. + /// + /// A representing the asynchronous completion of the test. + [Fact] + public async Task ClientReturnResultAcrossServersWithWrongReturnedTypeErrors() + { + var backplane = CreateBackplane(); + var manager1 = CreateNewHubLifetimeManager(backplane); + var manager2 = CreateNewHubLifetimeManager(backplane); + + using (var client1 = new TestClient()) + { + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + + await manager1.OnConnectedAsync(connection1).DefaultTimeout(); + + // Server2 asks for a result from client1 on Server1 + var resultTask = manager2.InvokeConnectionAsync(connection1.ConnectionId, "Result", new object[] { "test" }, cancellationToken: default); + var invocation = Assert.IsType(await client1.ReadAsync().DefaultTimeout()); + Assert.NotNull(invocation.InvocationId); + Assert.Equal("test", invocation.Arguments[0]); + + // Server1 gets the result from client1 and forwards to Server2 + await manager1.SetConnectionResultAsync(connection1.ConnectionId, CompletionMessage.WithResult(invocation.InvocationId, "wrong type")).DefaultTimeout(); + + var ex = await Assert.ThrowsAsync(() => resultTask).DefaultTimeout(); + Assert.StartsWith("Error trying to deserialize result to Int32.", ex.Message); + } + } } diff --git a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisLog.cs b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisLog.cs index 5df4804c52f4..1fb7b35e8a6c 100644 --- a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisLog.cs +++ b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisLog.cs @@ -61,6 +61,9 @@ public static void ConnectingToEndpoints(ILogger logger, EndPointCollection endp [LoggerMessage(14, LogLevel.Error, "Error connecting to Redis.", EventName = "ErrorConnecting")] public static partial void ErrorConnecting(ILogger logger, Exception ex); + [LoggerMessage(15, LogLevel.Warning, "Error parsing client result with protocol {HubProtocol}.", EventName = "ErrorParsingResult")] + public static partial void ErrorParsingResult(ILogger logger, string hubProtocol, Exception? ex); + // This isn't DefineMessage-based because it's just the simple TextWriter logging from ConnectionMultiplexer public static void ConnectionMultiplexerMessage(ILogger logger, string? message) { diff --git a/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs b/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs index 6748720a7e10..84c9e5d41af0 100644 --- a/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs +++ b/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs @@ -695,15 +695,73 @@ private async Task SubscribeToReturnResultsAsync() } var ros = completion.CompletionMessage; - var parseSuccess = protocol.TryParseMessage(ref ros, _clientResultsManager, out var hubMessage); - Debug.Assert(parseSuccess); + HubMessage? hubMessage = null; + bool retryForError = false; + try + { + var parseSuccess = protocol.TryParseMessage(ref ros, _clientResultsManager, out hubMessage); + retryForError = !parseSuccess; + } + catch + { + // Client returned wrong type? Or just an error from the HubProtocol, let's try with RawResult as the type and see if that works + retryForError = true; + } + + if (retryForError) + { + try + { + ros = completion.CompletionMessage; + // if this works then we know there was an error with the type the client returned, we'll replace the CompletionMessage below and provide an error to the application code + if (!protocol.TryParseMessage(ref ros, FakeInvocationBinder.Instance, out hubMessage)) + { + RedisLog.ErrorParsingResult(_logger, completion.ProtocolName, null); + return; + } + } + // Exceptions here would mean the HubProtocol implementation very likely has a bug, the other server has already deserialized the message (with RawResult) so it should be deserializable + // We don't know the InvocationId, we should let the application developer know and potentially surface the issue to the HubProtocol implementor + catch (Exception ex) + { + RedisLog.ErrorParsingResult(_logger, completion.ProtocolName, ex); + return; + } + } var invocationInfo = _clientResultsManager.RemoveInvocation(((CompletionMessage)hubMessage!).InvocationId!); + if (retryForError && invocationInfo is not null) + { + hubMessage = CompletionMessage.WithError(((CompletionMessage)hubMessage!).InvocationId!, $"Client result wasn't deserializable to {invocationInfo?.Type.Name}."); + } + invocationInfo?.Completion(invocationInfo?.Tcs!, (CompletionMessage)hubMessage!); }); } + private class FakeInvocationBinder : IInvocationBinder + { + public static readonly FakeInvocationBinder Instance = new FakeInvocationBinder(); + + private FakeInvocationBinder() { } + + public IReadOnlyList GetParameterTypes(string methodName) + { + throw new NotImplementedException(); + } + + public Type GetReturnType(string invocationId) + { + return typeof(RawResult); + } + + public Type GetStreamItemType(string streamId) + { + throw new NotImplementedException(); + } + } + private async Task EnsureRedisServerConnection() { if (_redisServerConnection == null)