Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
<ItemGroup>
<PackageVersion Include="Aspire.Hosting.AppHost" Version="13.2.1" />
<PackageVersion Include="Aspire.Hosting.Testing" Version="13.2.1" />
<PackageVersion Include="Azure.Data.Tables" Version="12.11.0" />
<PackageVersion Include="Azure.Storage.Blobs" Version="12.27.0" />
<PackageVersion Include="AutoFixture" Version="4.18.1" />
<PackageVersion Include="BenchmarkDotNet" Version="0.15.8" />
<PackageVersion Include="BenchmarkDotNet.Annotations" Version="0.15.8" />
Expand Down
12 changes: 6 additions & 6 deletions TUnit.Mocks.SourceGenerator/Builders/MockImplBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ private static void GeneratePartialMethodBody(CodeWriter writer, MockMemberModel
writer.AppendLine("return;");
writer.DecreaseIndent();
writer.AppendLine("}");
writer.AppendLine($"base.{method.Name}({argPassList});");
writer.AppendLine($"base.{method.Name}{GetTypeParameterList(method)}({argPassList});");
}
else if (method.IsVoid && method.IsAsync)
{
Expand All @@ -589,7 +589,7 @@ private static void GeneratePartialMethodBody(CodeWriter writer, MockMemberModel
}
writer.DecreaseIndent();
writer.AppendLine("}");
writer.AppendLine($"return base.{method.Name}({argPassList});");
writer.AppendLine($"return base.{method.Name}{GetTypeParameterList(method)}({argPassList});");
}
else if (method.IsAsync)
{
Expand Down Expand Up @@ -619,7 +619,7 @@ private static void GeneratePartialMethodBody(CodeWriter writer, MockMemberModel
}
writer.DecreaseIndent();
writer.AppendLine("}");
writer.AppendLine($"return base.{method.Name}({argPassList});");
writer.AppendLine($"return base.{method.Name}{GetTypeParameterList(method)}({argPassList});");
}
else if (method.IsRefStructReturn)
{
Expand All @@ -638,7 +638,7 @@ private static void GeneratePartialMethodBody(CodeWriter writer, MockMemberModel
}
writer.DecreaseIndent();
writer.AppendLine("}");
writer.AppendLine($"return base.{method.Name}({argPassList});");
writer.AppendLine($"return base.{method.Name}{GetTypeParameterList(method)}({argPassList});");
}
else if (method.IsReturnTypeStaticAbstractInterface)
{
Expand All @@ -650,7 +650,7 @@ private static void GeneratePartialMethodBody(CodeWriter writer, MockMemberModel
writer.AppendLine("return __result;");
writer.DecreaseIndent();
writer.AppendLine("}");
writer.AppendLine($"return base.{method.Name}({argPassList});");
writer.AppendLine($"return base.{method.Name}{GetTypeParameterList(method)}({argPassList});");
}
else
{
Expand All @@ -662,7 +662,7 @@ private static void GeneratePartialMethodBody(CodeWriter writer, MockMemberModel
writer.AppendLine("return __result;");
writer.DecreaseIndent();
writer.AppendLine("}");
writer.AppendLine($"return base.{method.Name}({argPassList});");
writer.AppendLine($"return base.{method.Name}{GetTypeParameterList(method)}({argPassList});");
}
}

Expand Down
99 changes: 96 additions & 3 deletions TUnit.Mocks.SourceGenerator/Builders/MockMembersBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,14 @@ public static string Build(MockTypeModel model)
{
bool firstMember = true;

// Pre-compute which methods need their `out` parameters kept in the extension
// signature to avoid CS0111 collisions. A method needs disambiguation when
// some other method on the model shares the same name AND the same
// matchable-parameter signature (i.e. parameters excluding out).
var methodsWithDisambiguation = ApplyOutDisambiguation(model.Methods);

// Methods
foreach (var method in model.Methods)
foreach (var method in methodsWithDisambiguation)
{
if (!firstMember) writer.AppendLine();
firstMember = false;
Expand Down Expand Up @@ -84,6 +90,20 @@ public static string Build(MockTypeModel model)
return writer.ToString();
}

private static void EmitOutParamDefaults(CodeWriter writer, MockMemberModel method)
{
if (!method.KeepOutParamsInExtensionSignature) return;
// Out params are assigned `default!` because the extension method never actually invokes
// the mocked method — it only *configures* a setup. The out value is never observed by
// caller code: this setup-configuration call returns a MockMethodCall, not the mocked
// result. For reference types this suppresses the CS8625 nullable warning on an unused
// assignment that exists solely to satisfy the `out` contract.
foreach (var op in method.Parameters.Where(p => p.Direction == ParameterDirection.Out))
{
writer.AppendLine($"{op.Name} = default!;");
}
}

private static bool ShouldGenerateTypedWrapper(MockMemberModel method, bool hasEvents)
{
if (method.IsGenericMethod) return false;
Expand Down Expand Up @@ -577,6 +597,54 @@ private static void GenerateMemberMethod(CodeWriter writer, MockMemberModel meth
}
}

/// <summary>
/// Returns the input methods, with <see cref="MockMemberModel.KeepOutParamsInExtensionSignature"/>
/// set to <c>true</c> on any method whose generated extension method would otherwise collide with
/// another overload. Methods are grouped by (name, type-arity, matchable-parameter signature);
/// any group with more than one entry causes its <c>out</c>-bearing members to be flagged.
/// The matchable-parameter signature includes parameter direction (ref/in/by-value) so that
/// overloads differing only by direction (e.g. <c>Foo(int)</c> vs <c>Foo(ref int)</c>) are not
/// treated as collisions.
/// </summary>
private static IEnumerable<MockMemberModel> ApplyOutDisambiguation(EquatableArray<MockMemberModel> methods)
{
var flagged = new HashSet<int>();
var byKey = new Dictionary<string, List<MockMemberModel>>(System.StringComparer.Ordinal);
foreach (var m in methods)
{
var matchable = string.Join(",", m.Parameters
.Where(p => p.Direction != ParameterDirection.Out)
.Select(p => $"{p.Direction}:{p.FullyQualifiedType}"));
var typeArity = m.TypeParameters.Length;
var key = $"{m.Name}`{typeArity}({matchable})";
if (!byKey.TryGetValue(key, out var list))
{
list = new List<MockMemberModel>();
byKey[key] = list;
}
list.Add(m);
}
foreach (var group in byKey.Values)
{
if (group.Count < 2) continue;
foreach (var m in group)
{
if (m.Parameters.Any(p => p.Direction == ParameterDirection.Out))
{
flagged.Add(m.MemberId);
}
}
}

if (flagged.Count == 0)
{
return methods;
}
return methods.Select(m => flagged.Contains(m.MemberId)
? m with { KeepOutParamsInExtensionSignature = true }
: m);
}

private static (bool UseTypedWrapper, string ReturnType, string SetupReturnType) GetReturnTypeInfo(
MockMemberModel method, MockTypeModel model, string safeName)
{
Expand Down Expand Up @@ -633,6 +701,8 @@ private static void EmitMemberMethodBody(CodeWriter writer, MockMemberModel meth

using (writer.Block($"public static {returnType} {safeMemberName}{typeParams}({fullParamList}){constraints}"))
{
EmitOutParamDefaults(writer, method);

// Build matchers array
var matchableParams = includeRefStructArgs
? method.Parameters.Where(p => p.Direction != ParameterDirection.Out).ToList()
Expand Down Expand Up @@ -717,7 +787,16 @@ private static void EmitSingleFuncOverload(CodeWriter writer, MockMemberModel me
for (int i = 0; i < method.Parameters.Length; i++)
{
var p = method.Parameters[i];
if (p.Direction == ParameterDirection.Out) continue;
if (p.Direction == ParameterDirection.Out)
{
// Keep out params only when needed to disambiguate colliding overloads.
// Callers of the disambiguated overload must write `out _` at the call site.
if (method.KeepOutParamsInExtensionSignature)
{
paramParts.Add($"out {p.FullyQualifiedType} {p.Name}");
}
continue;
}

if (funcIndices.Contains(i))
{
Expand Down Expand Up @@ -757,6 +836,8 @@ private static void EmitSingleFuncOverload(CodeWriter writer, MockMemberModel me

using (writer.Block($"public static {returnType} {safeMemberName}{typeParams}({fullParamList}){constraints}"))
{
EmitOutParamDefaults(writer, method);

// Convert Func params to Arg<T> via implicit conversion
foreach (var idx in funcIndices.OrderBy(i => i))
{
Expand Down Expand Up @@ -876,7 +957,19 @@ private static string GetArgParameterList(MockMemberModel method, bool includeRe
var parts = new List<string>();
foreach (var p in method.Parameters)
{
if (p.Direction == ParameterDirection.Out) continue;
if (p.Direction == ParameterDirection.Out)
{
// Normally out params are omitted from the extension signature so callers
// don't have to write `out _`. But when another overload of this method has
// the same matchable-parameter signature (e.g. GenerateSasUri(perms, expires)
// vs GenerateSasUri(perms, expires, out string)), we MUST keep the out param
// in the signature, otherwise CS0111 fires on the generated extensions.
if (method.KeepOutParamsInExtensionSignature)
{
parts.Add($"out {p.FullyQualifiedType} {p.Name}");
}
continue;
}
if (p.IsRefStruct)
{
if (includeRefStructArgs)
Expand Down
13 changes: 13 additions & 0 deletions TUnit.Mocks.SourceGenerator/Models/MockMemberModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,19 @@ internal sealed record MockMemberModel : IEquatable<MockMemberModel>
/// </summary>
public bool IsReturnTypeStaticAbstractInterface { get; init; }

/// <summary>
/// True when this method's <c>out</c> parameters must be kept in the generated extension
/// method signature to avoid CS0111 collisions with a sibling overload that shares the same
/// matchable-parameter signature (e.g. <c>BlobClient.GenerateSasUri(perms, expires)</c> vs
/// <c>GenerateSasUri(perms, expires, out string stringToSign)</c>). When true, callers must
/// pass <c>out _</c> at the call site for the disambiguated overload.
/// Computed transiently inside <see cref="Builders.MockMembersBuilder.Build"/> by examining
/// the full method set; never set on models flowing through the incremental pipeline.
/// Excluded from <see cref="Equals(MockMemberModel?)"/> / <see cref="GetHashCode"/> because
/// it is a derived per-build flag, not part of model identity.
/// </summary>
public bool KeepOutParamsInExtensionSignature { get; init; }

/// <summary>
/// For methods returning ReadOnlySpan&lt;T&gt; or Span&lt;T&gt;, the fully qualified element type.
/// Null for non-span return types. Used to support configurable span return values via array conversion.
Expand Down
49 changes: 49 additions & 0 deletions TUnit.Mocks.Tests/Issue5434Tests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
using Azure.Data.Tables;
using Azure.Storage.Blobs;
using Azure.Storage.Sas;

namespace TUnit.Mocks.Tests;

// Reproduction and regression tests for https://github.com/thomhurst/TUnit/issues/5434
// BlobClient: CS0111 duplicate GenerateSasUri / GenerateUserDelegationSasUri members in generated extensions.
// TableClient: CS0411 type inference failures for generic methods (GetEntity<T>, GetEntityAsync<T>,
// GetEntityIfExists<T>, GetEntityIfExistsAsync<T>, Query<T>, QueryAsync<T>) in generated impl factory.
public class Issue5434Tests
{
[Test]
public void Can_Mock_BlobClient()
{
var mock = Mock.Of<BlobClient>(MockBehavior.Strict);
_ = mock.Object;
}

[Test]
public void Can_Mock_TableClient()
{
var mock = Mock.Of<TableClient>(MockBehavior.Strict);
_ = mock.Object;
}

// Exercises the disambiguated overload that keeps `out string stringToSign` in its
// signature to distinguish it from GenerateSasUri(perms, expires). This call would not
// compile if `keepOutParams` disambiguation regressed.
[Test]
public void Can_Configure_BlobClient_GenerateSasUri_OutOverload()
{
var mock = Mock.Of<BlobClient>(MockBehavior.Loose);
_ = mock.GenerateSasUri(Arg.Any<BlobSasPermissions>(), Arg.Any<System.DateTimeOffset>(), out _);
}

// Exercises the generic-return-type override path. This would not compile if
// the base.GetEntity(...) call in the generated override was missing the <T> type argument.
[Test]
public void Can_Configure_TableClient_GetEntity_Generic()
{
var mock = Mock.Of<TableClient>(MockBehavior.Loose);
_ = mock.GetEntity<TableEntity>(
Arg.Any<string>(),
Arg.Any<string>(),
Arg.Any<System.Collections.Generic.IEnumerable<string>>(),
Arg.Any<System.Threading.CancellationToken>());
}
}
2 changes: 2 additions & 0 deletions TUnit.Mocks.Tests/TUnit.Mocks.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Azure.Data.Tables" />
<PackageReference Include="Azure.Storage.Blobs" />
<ProjectReference Include="..\TUnit.Mocks\TUnit.Mocks.csproj" />
<ProjectReference Include="..\TUnit.Mocks.Assertions\TUnit.Mocks.Assertions.csproj" />

Expand Down
Loading