Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Fix missing default constraint on explicit interface implementation…
…s with unconstrained generics

When TUnit.Mocks generates explicit interface implementations for methods
with unconstrained generic type parameters using nullable annotations (T?),
the required `where T : default` constraint was missing, causing CS0453 and
CS0539 compilation errors.

Closes #5362
  • Loading branch information
thomhurst committed Apr 4, 2026
commit 8186712ce51d158cf400438855f0ccbced159a05
25 changes: 25 additions & 0 deletions TUnit.Mocks.SourceGenerator.Tests/MockGeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -677,4 +677,29 @@ void M()

return VerifyGeneratorOutput(source);
}

[Test]
public Task Interface_With_Unconstrained_Nullable_Generic()
{
var source = """
using System.Threading.Tasks;
using TUnit.Mocks;

public interface IFoo
{
Task<T?> DoSomethingAsync<T>();
T? GetValue<T>();
}

public class TestUsage
{
void M()
{
var mock = Mock.Of<IFoo>();
}
}
""";

return VerifyGeneratorOutput(source);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
// <auto-generated/>
#nullable enable

namespace TUnit.Mocks.Generated
{
public sealed class IFooMock : global::TUnit.Mocks.Mock<global::IFoo>, global::IFoo
{
[global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
internal IFooMock(global::IFoo mockObject, global::TUnit.Mocks.MockEngine<global::IFoo> engine)
: base(mockObject, engine) { }

global::System.Threading.Tasks.Task<T?> global::IFoo.DoSomethingAsync<T>() where T : default => Object.DoSomethingAsync<T>();

T? global::IFoo.GetValue<T>() where T : default => Object.GetValue<T>();
}
}


// ===== FILE SEPARATOR =====

// <auto-generated/>
#nullable enable

namespace TUnit.Mocks.Generated
{
internal static class IFooMockFactory
{
[global::System.Runtime.CompilerServices.ModuleInitializer]
internal static void Register()
{
global::TUnit.Mocks.MockRegistry.RegisterFactory<global::IFoo>(Create);
}

internal static global::TUnit.Mocks.Mock<global::IFoo> Create(global::TUnit.Mocks.MockBehavior behavior, object[] constructorArgs)
{
if (constructorArgs.Length > 0) throw new global::System.ArgumentException($"Interface mock 'global::IFoo' does not support constructor arguments, but {constructorArgs.Length} were provided.");
var engine = new global::TUnit.Mocks.MockEngine<global::IFoo>(behavior);
var impl = new IFooMockImpl(engine);
engine.Raisable = impl;
var mock = new IFooMock(impl, engine);
return mock;
}
}
}


// ===== FILE SEPARATOR =====

// <auto-generated/>
#nullable enable

namespace TUnit.Mocks.Generated
{
internal sealed class IFooMockImpl : global::IFoo, global::TUnit.Mocks.IRaisable, global::TUnit.Mocks.IMockObject
{
private readonly global::TUnit.Mocks.MockEngine<global::IFoo> _engine;

[global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
global::TUnit.Mocks.IMock? global::TUnit.Mocks.IMockObject.MockWrapper { get; set; }

internal IFooMockImpl(global::TUnit.Mocks.MockEngine<global::IFoo> engine)
{
_engine = engine;
}

public global::System.Threading.Tasks.Task<T?> DoSomethingAsync<T>()
{
try
{
var __result = _engine.HandleCallWithReturn<T?>(0, "DoSomethingAsync", global::System.Array.Empty<object?>(), default);
if (global::TUnit.Mocks.Setup.RawReturnContext.TryConsume(out var __rawAsync))
{
if (__rawAsync is global::System.Threading.Tasks.Task<T?> __typedAsync) return __typedAsync;
throw new global::System.InvalidOperationException($"ReturnsAsync: expected global::System.Threading.Tasks.Task<T?> but got {__rawAsync?.GetType().Name ?? "null"}");
}
return global::System.Threading.Tasks.Task.FromResult<T?>(__result);
}
catch (global::System.Exception __ex)
{
return global::System.Threading.Tasks.Task.FromException<T?>(__ex);
}
}

public T? GetValue<T>()
{
return _engine.HandleCallWithReturn<T?>(1, "GetValue", global::System.Array.Empty<object?>(), default);
}

[global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
public void RaiseEvent(string eventName, object? args)
{
throw new global::System.InvalidOperationException($"No event named '{eventName}' exists on this mock.");
}
}
}


// ===== FILE SEPARATOR =====

// <auto-generated/>
#nullable enable

namespace TUnit.Mocks.Generated
{
public static class IFoo_MockMemberExtensions
{
public static global::TUnit.Mocks.MockMethodCall<T?> DoSomethingAsync<T>(this global::TUnit.Mocks.Mock<global::IFoo> mock)
{
var matchers = global::System.Array.Empty<global::TUnit.Mocks.Arguments.IArgumentMatcher>();
return new global::TUnit.Mocks.MockMethodCall<T?>(global::TUnit.Mocks.MockRegistry.GetEngine(mock), 0, "DoSomethingAsync", matchers);
}

public static global::TUnit.Mocks.MockMethodCall<T?> GetValue<T>(this global::TUnit.Mocks.Mock<global::IFoo> mock)
{
var matchers = global::System.Array.Empty<global::TUnit.Mocks.Arguments.IArgumentMatcher>();
return new global::TUnit.Mocks.MockMethodCall<T?>(global::TUnit.Mocks.MockRegistry.GetEngine(mock), 1, "GetValue", matchers);
}
}
}


// ===== FILE SEPARATOR =====

// <auto-generated/>
#nullable enable

namespace TUnit.Mocks
{
public static class IFoo_MockStaticExtension
{
extension(global::IFoo)
{
public static global::TUnit.Mocks.Generated.IFooMock Mock(global::TUnit.Mocks.MockBehavior behavior = global::TUnit.Mocks.MockBehavior.Loose)
{
return (global::TUnit.Mocks.Generated.IFooMock)global::TUnit.Mocks.Generated.IFooMockFactory.Create(behavior, []);
}
}
}
}


// ===== FILE SEPARATOR =====

// <auto-generated/>
#nullable enable

namespace TUnit.Mocks.Generated;
2 changes: 1 addition & 1 deletion TUnit.Mocks.SourceGenerator/Builders/MockBridgeBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ private static void GenerateStaticMethodDim(CodeWriter writer, MockMemberModel m
var signatureReturnType = (method.IsVoid && !method.IsAsync) ? "void" : method.ReturnType;
var paramList = MockImplBuilder.GetParameterList(method);
var typeParams = MockImplBuilder.GetTypeParameterList(method);
var constraints = MockImplBuilder.GetConstraintClauses(method);
var constraints = MockImplBuilder.GetConstraintClauses(method, forExplicitImplementation: true);

using (writer.Block($"static {signatureReturnType} {method.ExplicitInterfaceName}.{method.Name}{typeParams}({paramList}){constraints}"))
{
Expand Down
10 changes: 7 additions & 3 deletions TUnit.Mocks.SourceGenerator/Builders/MockImplBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1110,10 +1110,10 @@ private static string FormatTypeParameterList(EquatableArray<MockTypeParameterMo

// Only for non-override declarations (interface impls, extension methods).
// C# prohibits restating constraints on override methods (CS0460).
internal static string GetConstraintClauses(MockMemberModel method) =>
FormatConstraintClauses(method.TypeParameters);
internal static string GetConstraintClauses(MockMemberModel method, bool forExplicitImplementation = false) =>
FormatConstraintClauses(method.TypeParameters, forExplicitImplementation);

private static string FormatConstraintClauses(EquatableArray<MockTypeParameterModel> typeParameters)
private static string FormatConstraintClauses(EquatableArray<MockTypeParameterModel> typeParameters, bool forExplicitImplementation = false)
{
var clauses = new List<string>();
foreach (var tp in typeParameters)
Expand All @@ -1122,6 +1122,10 @@ private static string FormatConstraintClauses(EquatableArray<MockTypeParameterMo
{
clauses.Add($"where {tp.Name} : {tp.Constraints}");
}
else if (forExplicitImplementation && tp.NeedsDefaultConstraint)
{
clauses.Add($"where {tp.Name} : default");
}
}
return clauses.Count > 0 ? " " + string.Join(' ', clauses) : "";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ private static void GenerateMethodForwarding(CodeWriter writer, MockMemberModel
var interfaceName = method.ExplicitInterfaceName ?? method.DeclaringInterfaceName ?? model.FullyQualifiedName;
var paramList = MockImplBuilder.GetParameterList(method);
var typeParams = MockImplBuilder.GetTypeParameterList(method);
var constraints = MockImplBuilder.GetConstraintClauses(method);
var constraints = MockImplBuilder.GetConstraintClauses(method, forExplicitImplementation: true);
var argPassList = MockImplBuilder.GetArgPassList(method);
var returnType = (method.IsVoid && !method.IsAsync) ? "void" : method.ReturnType;

Expand Down
3 changes: 2 additions & 1 deletion TUnit.Mocks.SourceGenerator/Discovery/MemberDiscovery.cs
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,8 @@ private static MockMemberModel CreateMethodModel(IMethodSymbol method, ref int m
method.TypeParameters.Select(tp => new MockTypeParameterModel
{
Name = tp.Name,
Constraints = tp.GetGenericConstraints()
Constraints = tp.GetGenericConstraints(),
NeedsDefaultConstraint = tp.IsUnconstrainedWithNullableUsage(method)
}).ToImmutableArray()
),
ExplicitInterfaceName = explicitInterfaceName,
Expand Down
59 changes: 58 additions & 1 deletion TUnit.Mocks.SourceGenerator/Extensions/MethodSymbolExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,19 @@ public static ParameterDirection GetParameterDirection(this IParameterSymbol par
};
}

private static bool IsUnconstrained(this ITypeParameterSymbol typeParam) =>
!typeParam.HasReferenceTypeConstraint &&
!typeParam.HasValueTypeConstraint &&
!typeParam.HasUnmanagedTypeConstraint &&
!typeParam.HasNotNullConstraint &&
typeParam.ConstraintTypes.Length == 0 &&
!typeParam.HasConstructorConstraint;

public static string GetGenericConstraints(this ITypeParameterSymbol typeParam)
{
if (typeParam.IsUnconstrained())
return "";

var constraints = new List<string>();

if (typeParam.HasReferenceTypeConstraint)
Expand All @@ -40,7 +51,53 @@ public static string GetGenericConstraints(this ITypeParameterSymbol typeParam)
if (typeParam.HasConstructorConstraint)
constraints.Add("new()");

return constraints.Count > 0 ? string.Join(", ", constraints) : "";
return string.Join(", ", constraints);
}

public static bool IsUnconstrainedWithNullableUsage(this ITypeParameterSymbol typeParam, IMethodSymbol method)
{
if (!typeParam.IsUnconstrained())
{
return false;
}

if (HasNullableTypeParameter(method.ReturnType, typeParam))
return true;

foreach (var param in method.Parameters)
{
if (HasNullableTypeParameter(param.Type, typeParam))
return true;
}

return false;
}

private static bool HasNullableTypeParameter(ITypeSymbol type, ITypeParameterSymbol typeParam)
{
if (type is ITypeParameterSymbol tp &&
SymbolEqualityComparer.Default.Equals(tp.OriginalDefinition, typeParam.OriginalDefinition) &&
tp.NullableAnnotation == NullableAnnotation.Annotated)
{
return true;
}

if (type is INamedTypeSymbol named)
{
foreach (var arg in named.TypeArguments)
{
if (HasNullableTypeParameter(arg, typeParam))
return true;
}
}

if (type is IArrayTypeSymbol array)
{
if (HasNullableTypeParameter(array.ElementType, typeParam))
return true;
}

return false;
}

public static string GetParameterList(this IMethodSymbol method)
Expand Down
4 changes: 3 additions & 1 deletion TUnit.Mocks.SourceGenerator/Models/MockTypeParameterModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ internal sealed record MockTypeParameterModel : IEquatable<MockTypeParameterMode
{
public string Name { get; init; } = "";
public string Constraints { get; init; } = "";
public bool NeedsDefaultConstraint { get; init; }

public bool Equals(MockTypeParameterModel? other)
{
if (other is null) return false;
return Name == other.Name && Constraints == other.Constraints;
return Name == other.Name && Constraints == other.Constraints && NeedsDefaultConstraint == other.NeedsDefaultConstraint;
}

public override int GetHashCode()
Expand All @@ -20,6 +21,7 @@ public override int GetHashCode()
int hash = 17;
hash = hash * 31 + Name.GetHashCode();
hash = hash * 31 + Constraints.GetHashCode();
hash = hash * 31 + NeedsDefaultConstraint.GetHashCode();
return hash;
}
}
Expand Down
Loading