Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactor(mocks): address PR review feedback for #5434 fix
- Move `keepOutParams` from a 5-signature threaded bool into a
  `KeepOutParamsInExtensionSignature` property on MockMemberModel.
  The disambiguation decision is now baked into the model in Build()
  via `ApplyOutDisambiguation`, so emit code reads the flag from the
  member rather than receiving it as a parameter.
- Include parameter direction in the disambiguation collision key so
  overloads differing only by ref vs by-value are not falsely flagged.
- Document why `default!` is used for out-param initialization in
  EmitOutParamDefaults (the extension never observes the value).
- Add a comment in EmitSingleFuncOverload noting that disambiguated
  overloads require `out _` at the call site.
- Add behavioral tests that actually invoke the disambiguated
  GenerateSasUri(..., out _) overload and the generic GetEntity<T>
  override path, so a future regression in the emitted bodies is
  caught.
  • Loading branch information
thomhurst committed Apr 7, 2026
commit 2a828f9d4f210640c7a6b2c373aa89a0f44880da
80 changes: 49 additions & 31 deletions TUnit.Mocks.SourceGenerator/Builders/MockMembersBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,14 @@ public static string Build(MockTypeModel model)
// signature to avoid CS0111 collisions. A method needs disambiguation when
// some other method on the model shares the same name AND the same
// matchable-parameter signature (i.e. parameters excluding out).
var needsOutDisambiguation = ComputeOutDisambiguationSet(model.Methods);
var methodsWithDisambiguation = ApplyOutDisambiguation(model.Methods);

// Methods
foreach (var method in model.Methods)
foreach (var method in methodsWithDisambiguation)
{
if (!firstMember) writer.AppendLine();
firstMember = false;
GenerateMemberMethod(writer, method, model, safeName,
keepOutParams: needsOutDisambiguation.Contains(method.MemberId));
GenerateMemberMethod(writer, method, model, safeName);
}

// Properties -- extension properties via C# 14 extension blocks
Expand Down Expand Up @@ -91,9 +90,14 @@ public static string Build(MockTypeModel model)
return writer.ToString();
}

private static void EmitOutParamDefaults(CodeWriter writer, MockMemberModel method, bool keepOutParams)
private static void EmitOutParamDefaults(CodeWriter writer, MockMemberModel method)
{
if (!keepOutParams) return;
if (!method.KeepOutParamsInExtensionSignature) return;
// Out params are assigned `default!` because the extension method never actually invokes
// the mocked method — it only *configures* a setup. The out value is never observed by
// caller code: this setup-configuration call returns a MockMethodCall, not the mocked
// result. For reference types this suppresses the CS8625 nullable warning on an unused
// assignment that exists solely to satisfy the `out` contract.
foreach (var op in method.Parameters.Where(p => p.Direction == ParameterDirection.Out))
{
writer.AppendLine($"{op.Name} = default!;");
Expand Down Expand Up @@ -574,37 +578,43 @@ private static string CastArg(MockParameterModel p, int index)
return $"({p.FullyQualifiedType})args[{index}]{bang}";
}

private static void GenerateMemberMethod(CodeWriter writer, MockMemberModel method, MockTypeModel model, string safeName, bool keepOutParams)
private static void GenerateMemberMethod(CodeWriter writer, MockMemberModel method, MockTypeModel model, string safeName)
{
if (method.HasRefStructParams)
{
writer.AppendLine("#if NET9_0_OR_GREATER");
EmitMemberMethodBody(writer, method, model, safeName, includeRefStructArgs: true, keepOutParams);
EmitFuncOverloads(writer, method, model, safeName, includeRefStructArgs: true, keepOutParams);
EmitMemberMethodBody(writer, method, model, safeName, includeRefStructArgs: true);
EmitFuncOverloads(writer, method, model, safeName, includeRefStructArgs: true);
writer.AppendLine("#else");
EmitMemberMethodBody(writer, method, model, safeName, includeRefStructArgs: false, keepOutParams);
EmitFuncOverloads(writer, method, model, safeName, includeRefStructArgs: false, keepOutParams);
EmitMemberMethodBody(writer, method, model, safeName, includeRefStructArgs: false);
EmitFuncOverloads(writer, method, model, safeName, includeRefStructArgs: false);
writer.AppendLine("#endif");
}
else
{
EmitMemberMethodBody(writer, method, model, safeName, includeRefStructArgs: false, keepOutParams);
EmitFuncOverloads(writer, method, model, safeName, includeRefStructArgs: false, keepOutParams);
EmitMemberMethodBody(writer, method, model, safeName, includeRefStructArgs: false);
EmitFuncOverloads(writer, method, model, safeName, includeRefStructArgs: false);
}
}

private static HashSet<int> ComputeOutDisambiguationSet(EquatableArray<MockMemberModel> methods)
/// <summary>
/// Returns the input methods, with <see cref="MockMemberModel.KeepOutParamsInExtensionSignature"/>
/// set to <c>true</c> on any method whose generated extension method would otherwise collide with
/// another overload. Methods are grouped by (name, type-arity, matchable-parameter signature);
/// any group with more than one entry causes its <c>out</c>-bearing members to be flagged.
/// The matchable-parameter signature includes parameter direction (ref/in/by-value) so that
/// overloads differing only by direction (e.g. <c>Foo(int)</c> vs <c>Foo(ref int)</c>) are not
/// treated as collisions.
/// </summary>
private static IEnumerable<MockMemberModel> ApplyOutDisambiguation(EquatableArray<MockMemberModel> methods)
{
// Group methods by (name, matchable-parameter signature). Any group with >1 entry
// contains methods that would otherwise emit colliding extension overloads — flag
// every member of such a group whose original method has out parameters.
var result = new HashSet<int>();
var flagged = new HashSet<int>();
var byKey = new Dictionary<string, List<MockMemberModel>>(System.StringComparer.Ordinal);
foreach (var m in methods)
{
var matchable = string.Join(",", m.Parameters
.Where(p => p.Direction != ParameterDirection.Out)
.Select(p => p.FullyQualifiedType));
.Select(p => $"{p.Direction}:{p.FullyQualifiedType}"));
var typeArity = m.TypeParameters.Length;
var key = $"{m.Name}`{typeArity}({matchable})";
if (!byKey.TryGetValue(key, out var list))
Expand All @@ -621,11 +631,18 @@ private static HashSet<int> ComputeOutDisambiguationSet(EquatableArray<MockMembe
{
if (m.Parameters.Any(p => p.Direction == ParameterDirection.Out))
{
result.Add(m.MemberId);
flagged.Add(m.MemberId);
}
}
}
return result;

if (flagged.Count == 0)
{
return methods;
}
return methods.Select(m => flagged.Contains(m.MemberId)
? m with { KeepOutParamsInExtensionSignature = true }
: m);
}

private static (bool UseTypedWrapper, string ReturnType, string SetupReturnType) GetReturnTypeInfo(
Expand Down Expand Up @@ -659,11 +676,11 @@ private static (bool UseTypedWrapper, string ReturnType, string SetupReturnType)
return (useTypedWrapper, returnType, setupReturnType);
}

private static void EmitMemberMethodBody(CodeWriter writer, MockMemberModel method, MockTypeModel model, string safeName, bool includeRefStructArgs, bool keepOutParams)
private static void EmitMemberMethodBody(CodeWriter writer, MockMemberModel method, MockTypeModel model, string safeName, bool includeRefStructArgs)
{
var (useTypedWrapper, returnType, setupReturnType) = GetReturnTypeInfo(method, model, safeName);

var paramList = GetArgParameterList(method, includeRefStructArgs, keepOutParams);
var paramList = GetArgParameterList(method, includeRefStructArgs);
var typeParams = MockImplBuilder.GetTypeParameterList(method);
var constraints = MockImplBuilder.GetConstraintClauses(method);

Expand All @@ -684,7 +701,7 @@ private static void EmitMemberMethodBody(CodeWriter writer, MockMemberModel meth

using (writer.Block($"public static {returnType} {safeMemberName}{typeParams}({fullParamList}){constraints}"))
{
EmitOutParamDefaults(writer, method, keepOutParams);
EmitOutParamDefaults(writer, method);

// Build matchers array
var matchableParams = includeRefStructArgs
Expand Down Expand Up @@ -737,7 +754,7 @@ private static List<int> GetFuncEligibleParamIndices(MockMemberModel method)
}

private static void EmitFuncOverloads(CodeWriter writer, MockMemberModel method, MockTypeModel model,
string safeName, bool includeRefStructArgs, bool keepOutParams)
string safeName, bool includeRefStructArgs)
{
var eligible = GetFuncEligibleParamIndices(method);
if (eligible.Count == 0 || eligible.Count > MaxFuncOverloadParams) return;
Expand All @@ -746,12 +763,12 @@ private static void EmitFuncOverloads(CodeWriter writer, MockMemberModel method,
for (int mask = 1; mask <= totalMasks; mask++)
{
writer.AppendLine();
EmitSingleFuncOverload(writer, method, model, safeName, eligible, mask, includeRefStructArgs, keepOutParams);
EmitSingleFuncOverload(writer, method, model, safeName, eligible, mask, includeRefStructArgs);
}
}

private static void EmitSingleFuncOverload(CodeWriter writer, MockMemberModel method, MockTypeModel model,
string safeName, List<int> eligibleIndices, int funcMask, bool includeRefStructArgs, bool keepOutParams)
string safeName, List<int> eligibleIndices, int funcMask, bool includeRefStructArgs)
{
// Determine which parameter indices use Func<T, bool>
var funcIndices = new HashSet<int>();
Expand All @@ -773,7 +790,8 @@ private static void EmitSingleFuncOverload(CodeWriter writer, MockMemberModel me
if (p.Direction == ParameterDirection.Out)
{
// Keep out params only when needed to disambiguate colliding overloads.
if (keepOutParams)
// Callers of the disambiguated overload must write `out _` at the call site.
if (method.KeepOutParamsInExtensionSignature)
{
paramParts.Add($"out {p.FullyQualifiedType} {p.Name}");
}
Expand Down Expand Up @@ -818,7 +836,7 @@ private static void EmitSingleFuncOverload(CodeWriter writer, MockMemberModel me

using (writer.Block($"public static {returnType} {safeMemberName}{typeParams}({fullParamList}){constraints}"))
{
EmitOutParamDefaults(writer, method, keepOutParams);
EmitOutParamDefaults(writer, method);

// Convert Func params to Arg<T> via implicit conversion
foreach (var idx in funcIndices.OrderBy(i => i))
Expand Down Expand Up @@ -934,7 +952,7 @@ private static void GenerateRaiseExtensionMethods(CodeWriter writer, MockTypeMod
}
}

private static string GetArgParameterList(MockMemberModel method, bool includeRefStructArgs, bool keepOutParams)
private static string GetArgParameterList(MockMemberModel method, bool includeRefStructArgs)
{
var parts = new List<string>();
foreach (var p in method.Parameters)
Expand All @@ -946,7 +964,7 @@ private static string GetArgParameterList(MockMemberModel method, bool includeRe
// the same matchable-parameter signature (e.g. GenerateSasUri(perms, expires)
// vs GenerateSasUri(perms, expires, out string)), we MUST keep the out param
// in the signature, otherwise CS0111 fires on the generated extensions.
if (keepOutParams)
if (method.KeepOutParamsInExtensionSignature)
{
parts.Add($"out {p.FullyQualifiedType} {p.Name}");
}
Expand Down
12 changes: 12 additions & 0 deletions TUnit.Mocks.SourceGenerator/Models/MockMemberModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@ internal sealed record MockMemberModel : IEquatable<MockMemberModel>
/// </summary>
public bool IsReturnTypeStaticAbstractInterface { get; init; }

/// <summary>
/// True when this method's <c>out</c> parameters must be kept in the generated extension
/// method signature to avoid CS0111 collisions with a sibling overload that shares the same
/// matchable-parameter signature (e.g. <c>BlobClient.GenerateSasUri(perms, expires)</c> vs
/// <c>GenerateSasUri(perms, expires, out string stringToSign)</c>). When true, callers must
/// pass <c>out _</c> at the call site for the disambiguated overload.
/// Computed across the full method set in <see cref="Builders.MockMembersBuilder"/>.
/// </summary>
public bool KeepOutParamsInExtensionSignature { get; init; }

/// <summary>
/// For methods returning ReadOnlySpan&lt;T&gt; or Span&lt;T&gt;, the fully qualified element type.
/// Null for non-span return types. Used to support configurable span return values via array conversion.
Expand Down Expand Up @@ -86,6 +96,7 @@ public bool Equals(MockMemberModel? other)
&& IsRefStructReturn == other.IsRefStructReturn
&& IsStaticAbstract == other.IsStaticAbstract
&& IsReturnTypeStaticAbstractInterface == other.IsReturnTypeStaticAbstractInterface
&& KeepOutParamsInExtensionSignature == other.KeepOutParamsInExtensionSignature
&& SpanReturnElementType == other.SpanReturnElementType;
}

Expand All @@ -100,6 +111,7 @@ public override int GetHashCode()
hash = hash * 31 + Parameters.GetHashCode();
hash = hash * 31 + IsStaticAbstract.GetHashCode();
hash = hash * 31 + IsReturnTypeStaticAbstractInterface.GetHashCode();
hash = hash * 31 + KeepOutParamsInExtensionSignature.GetHashCode();
hash = hash * 31 + (ExplicitInterfaceName?.GetHashCode() ?? 0);
hash = hash * 31 + (DeclaringInterfaceName?.GetHashCode() ?? 0);
return hash;
Expand Down
26 changes: 25 additions & 1 deletion TUnit.Mocks.Tests/Issue5434Tests.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
using Azure.Data.Tables;
using Azure.Storage.Blobs;
using Azure.Storage.Sas;

namespace TUnit.Mocks.Tests;

// Reproduction for https://github.com/thomhurst/TUnit/issues/5434
// Reproduction and regression tests for https://github.com/thomhurst/TUnit/issues/5434
// BlobClient: CS0111 duplicate GenerateSasUri / GenerateUserDelegationSasUri members in generated extensions.
// TableClient: CS0411 type inference failures for generic methods (GetEntity<T>, GetEntityAsync<T>,
// GetEntityIfExists<T>, GetEntityIfExistsAsync<T>, Query<T>, QueryAsync<T>) in generated impl factory.
Expand All @@ -22,4 +23,27 @@ public void Can_Mock_TableClient()
var mock = Mock.Of<TableClient>(MockBehavior.Strict);
_ = mock.Object;
}

// Exercises the disambiguated overload that keeps `out string stringToSign` in its
// signature to distinguish it from GenerateSasUri(perms, expires). This call would not
// compile if `keepOutParams` disambiguation regressed.
[Test]
public void Can_Configure_BlobClient_GenerateSasUri_OutOverload()
{
var mock = Mock.Of<BlobClient>(MockBehavior.Loose);
_ = mock.GenerateSasUri(Arg.Any<BlobSasPermissions>(), Arg.Any<System.DateTimeOffset>(), out _);
}

// Exercises the generic-return-type override path. This would not compile if
// the base.GetEntity(...) call in the generated override was missing the <T> type argument.
[Test]
public void Can_Configure_TableClient_GetEntity_Generic()
{
var mock = Mock.Of<TableClient>(MockBehavior.Loose);
_ = mock.GetEntity<TableEntity>(
Arg.Any<string>(),
Arg.Any<string>(),
Arg.Any<System.Collections.Generic.IEnumerable<string>>(),
Arg.Any<System.Threading.CancellationToken>());
}
}
Loading