Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,22 @@ namespace TUnit.Mocks.Generated

public void Process(global::System.ReadOnlySpan<byte> data)
{
_engine.HandleCall(0, "Process", global::System.Array.Empty<object?>());
#if NET9_0_OR_GREATER
var __args = new object?[] { null };
#else
var __args = global::System.Array.Empty<object?>();
#endif
_engine.HandleCall(0, "Process", __args);
}

public int Parse(global::System.ReadOnlySpan<char> text)
{
return _engine.HandleCallWithReturn<int>(1, "Parse", global::System.Array.Empty<object?>(), default);
#if NET9_0_OR_GREATER
var __args = new object?[] { null };
#else
var __args = global::System.Array.Empty<object?>();
#endif
return _engine.HandleCallWithReturn<int>(1, "Parse", __args, default);
}

public string GetName()
Expand All @@ -72,17 +82,33 @@ namespace TUnit.Mocks.Generated
{
public static class IBufferProcessor_MockMemberExtensions
{
#if NET9_0_OR_GREATER
public static global::TUnit.Mocks.VoidMockMethodCall Process(this global::TUnit.Mocks.Mock<global::IBufferProcessor> mock, global::TUnit.Mocks.Arguments.RefStructArg<global::System.ReadOnlySpan<byte>> data)
{
var matchers = new global::TUnit.Mocks.Arguments.IArgumentMatcher[] { data.Matcher };
return new global::TUnit.Mocks.VoidMockMethodCall(mock.Engine, 0, "Process", matchers);
}
#else
public static global::TUnit.Mocks.VoidMockMethodCall Process(this global::TUnit.Mocks.Mock<global::IBufferProcessor> mock)
{
var matchers = global::System.Array.Empty<global::TUnit.Mocks.Arguments.IArgumentMatcher>();
return new global::TUnit.Mocks.VoidMockMethodCall(mock.Engine, 0, "Process", matchers);
}
#endif

#if NET9_0_OR_GREATER
public static global::TUnit.Mocks.MockMethodCall<int> Parse(this global::TUnit.Mocks.Mock<global::IBufferProcessor> mock, global::TUnit.Mocks.Arguments.RefStructArg<global::System.ReadOnlySpan<char>> text)
{
var matchers = new global::TUnit.Mocks.Arguments.IArgumentMatcher[] { text.Matcher };
return new global::TUnit.Mocks.MockMethodCall<int>(mock.Engine, 1, "Parse", matchers);
}
#else
public static global::TUnit.Mocks.MockMethodCall<int> Parse(this global::TUnit.Mocks.Mock<global::IBufferProcessor> mock)
{
var matchers = global::System.Array.Empty<global::TUnit.Mocks.Arguments.IArgumentMatcher>();
return new global::TUnit.Mocks.MockMethodCall<int>(mock.Engine, 1, "Parse", matchers);
}
#endif

public static global::TUnit.Mocks.MockMethodCall<string> GetName(this global::TUnit.Mocks.Mock<global::IBufferProcessor> mock)
{
Expand Down
66 changes: 57 additions & 9 deletions TUnit.Mocks.SourceGenerator/Builders/MockImplBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,21 @@ private static void GenerateWrapMethodBody(CodeWriter writer, MockMemberModel me
}
}

var argsArray = GetArgsArrayExpression(method);
string argsArray;
if (HasRefStructParams(method))
{
// Emit #if block so the variable is defined under both branches
writer.AppendLine("#if NET9_0_OR_GREATER");
writer.AppendLine($"var __args = {GetArgsArrayExpression(method, true)};");
writer.AppendLine("#else");
writer.AppendLine($"var __args = {GetArgsArrayExpression(method, false)};");
writer.AppendLine("#endif");
argsArray = "__args";
}
else
{
argsArray = GetArgsArrayExpression(method, false);
}
var argPassList = GetArgPassList(method);

if (method.IsVoid && !method.IsAsync)
Expand Down Expand Up @@ -461,7 +475,20 @@ private static void GeneratePartialMethodBody(CodeWriter writer, MockMemberModel
}
}

var argsArray = GetArgsArrayExpression(method);
string argsArray;
if (HasRefStructParams(method))
{
writer.AppendLine("#if NET9_0_OR_GREATER");
writer.AppendLine($"var __args = {GetArgsArrayExpression(method, true)};");
writer.AppendLine("#else");
writer.AppendLine($"var __args = {GetArgsArrayExpression(method, false)};");
writer.AppendLine("#endif");
argsArray = "__args";
}
else
{
argsArray = GetArgsArrayExpression(method, false);
}
var argPassList = GetArgPassList(method);

if (method.IsVoid && !method.IsAsync)
Expand Down Expand Up @@ -551,7 +578,20 @@ private static void GenerateEngineDispatchBody(CodeWriter writer, MockMemberMode
}
}

var argsArray = GetArgsArrayExpression(method);
string argsArray;
if (HasRefStructParams(method))
{
writer.AppendLine("#if NET9_0_OR_GREATER");
writer.AppendLine($"var __args = {GetArgsArrayExpression(method, true)};");
writer.AppendLine("#else");
writer.AppendLine($"var __args = {GetArgsArrayExpression(method, false)};");
writer.AppendLine("#endif");
argsArray = "__args";
}
else
{
argsArray = GetArgsArrayExpression(method, false);
}

var hasOutRef = HasOutRefParams(method);

Expand Down Expand Up @@ -955,14 +995,22 @@ private static void EmitOutRefReadback(CodeWriter writer, MockMemberModel method
}
}

private static string GetArgsArrayExpression(MockMemberModel method)
private static bool HasRefStructParams(MockMemberModel method)
=> method.Parameters.Any(p => p.IsRefStruct && p.Direction != ParameterDirection.Out);

private static string GetArgsArrayExpression(MockMemberModel method, bool includeRefStructSentinels)
{
// Only include non-out, non-ref-struct parameters in args array
// (ref structs cannot be boxed into object?[])
var matchableParams = method.Parameters.Where(p => p.Direction != ParameterDirection.Out && !p.IsRefStruct).ToList();
var nonOutParams = method.Parameters.Where(p => p.Direction != ParameterDirection.Out).ToList();
if (includeRefStructSentinels)
{
if (nonOutParams.Count == 0) return "global::System.Array.Empty<object?>()";
var args = string.Join(", ", nonOutParams.Select(p => p.IsRefStruct ? "null" : p.Name));
return $"new object?[] {{ {args} }}";
}
var matchableParams = nonOutParams.Where(p => !p.IsRefStruct).ToList();
if (matchableParams.Count == 0) return "global::System.Array.Empty<object?>()";
var args = string.Join(", ", matchableParams.Select(p => p.Name));
return $"new object?[] {{ {args} }}";
var argsStr = string.Join(", ", matchableParams.Select(p => p.Name));
return $"new object?[] {{ {argsStr} }}";
}

/// <summary>
Expand Down
Loading
Loading