Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Pass context to IMarshallingGenerator.AsArgument
Add placeholders for handling PreserveSig
  • Loading branch information
elinor-fung committed Sep 11, 2020
commit 6a48f8eadd7b0dabac411278f78c98c33cfddb6b
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ private void PrintGeneratedSource(
.WithBody(stub.StubCode);

// Create the DllImport declaration.
// [TODO] Don't include PreserveSig=false once that is handled by the generated stub
var dllImport = stub.DllImportDeclaration.AddAttributeLists(
AttributeList(
SingletonSeparatedList<AttributeSyntax>(dllImportAttr)));
Expand Down
32 changes: 26 additions & 6 deletions DllImportGenerator/DllImportGenerator/DllImportStub.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ public IEnumerable<ParameterSyntax> StubParameters
{
foreach (var typeinfo in paramsTypeInfo)
{
//if (typeinfo.ManagedIndex != TypePositionInfo.UnsetIndex)
if (typeinfo.ManagedIndex != TypePositionInfo.UnsetIndex
&& typeinfo.ManagedIndex != TypePositionInfo.ReturnIndex)
{
yield return Parameter(Identifier(typeinfo.InstanceIdentifier))
.WithType(typeinfo.ManagedType.AsTypeSyntax())
Expand Down Expand Up @@ -144,16 +145,35 @@ public static DllImportStub Create(

// Determine parameter and return types
var paramsTypeInfo = new List<TypePositionInfo>();
foreach (var param in method.Parameters)
for (int i = 0; i < method.Parameters.Length; i++)
{
paramsTypeInfo.Add(TypePositionInfo.CreateForParameter(param, compilation));
var param = method.Parameters[i];
var typeInfo = TypePositionInfo.CreateForParameter(param, compilation);
typeInfo.ManagedIndex = i;
typeInfo.NativeIndex = paramsTypeInfo.Count;
paramsTypeInfo.Add(typeInfo);
}

var retTypeInfo = TypePositionInfo.CreateForType(method.ReturnType, method.GetReturnTypeAttributes(), compilation);
TypePositionInfo retTypeInfo = TypePositionInfo.CreateForType(method.ReturnType, method.GetReturnTypeAttributes(), compilation);
retTypeInfo.ManagedIndex = TypePositionInfo.ReturnIndex;
retTypeInfo.NativeIndex = TypePositionInfo.ReturnIndex;
if (!dllImportData.PreserveSig)
{
// [TODO] Create type info for native HRESULT return
// retTypeInfo = ...

// [TODO] Create type info for native out param
// if (!method.ReturnsVoid)
// {
// TypePositionInfo nativeOutInfo = ...;
// nativeOutInfo.ManagedIndex = TypePositionInfo.ReturnIndex;
// nativeOutInfo.NativeIndex = paramsTypeInfo.Count;
// paramsTypeInfo.Add(nativeOutInfo);
// }
}

// Generate stub code
string dllImportName = method.Name + "__PInvoke__";
var (code, dllImport) = StubCodeContext.GenerateSyntax(dllImportName, paramsTypeInfo, retTypeInfo);
var (code, dllImport) = StubCodeContext.GenerateSyntax(method, paramsTypeInfo, retTypeInfo);

return new DllImportStub()
{
Expand Down
26 changes: 13 additions & 13 deletions DllImportGenerator/DllImportGenerator/Marshalling/BoolMarshaller.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,18 @@ public TypeSyntax AsNativeType(TypePositionInfo info)
return info.NativeType.AsTypeSyntax();
}

public ArgumentSyntax AsArgument(TypePositionInfo info)
public ParameterSyntax AsParameter(TypePositionInfo info)
{
var type = info.IsByRef
? PointerType(info.NativeType.AsTypeSyntax())
: info.NativeType.AsTypeSyntax();
return Parameter(Identifier(info.InstanceIdentifier))
.WithType(type);
}

public ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context)
{
string identifier = StubCodeContext.ToNativeIdentifer(info.InstanceIdentifier);
string identifier = context.GetIdentifiers(info).native;
if (info.IsByRef)
{
return Argument(
Expand All @@ -27,22 +36,13 @@ public ArgumentSyntax AsArgument(TypePositionInfo info)
return Argument(IdentifierName(identifier));
}

public ParameterSyntax AsParameter(TypePositionInfo info)
{
var type = info.IsByRef
? PointerType(info.NativeType.AsTypeSyntax())
: info.NativeType.AsTypeSyntax();
return Parameter(Identifier(info.InstanceIdentifier))
.WithType(type);
}

public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeContext context)
{
(string managedIdentifier, string nativeIdentifier) = context.GetIdentifiers(info);
switch (context.CurrentStage)
{
case StubCodeContext.Stage.Setup:
if (info.IsReturnType)
if (info.IsManagedReturnPosition)
nativeIdentifier = context.GenerateReturnNativeIdentifier();

yield return LocalDeclarationStatement(
Expand All @@ -69,7 +69,7 @@ public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeCont

break;
case StubCodeContext.Stage.Unmarshal:
if (info.IsReturnType || info.IsByRef)
if (info.IsManagedReturnPosition || info.IsByRef)
{
// <managedIdentifier> = <nativeIdentifier> != 0;
yield return ExpressionStatement(
Expand Down
12 changes: 6 additions & 6 deletions DllImportGenerator/DllImportGenerator/Marshalling/Forwarder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,19 @@ public TypeSyntax AsNativeType(TypePositionInfo info)
return info.ManagedType.AsTypeSyntax();
}

public ArgumentSyntax AsArgument(TypePositionInfo info)
{
return Argument(IdentifierName(info.InstanceIdentifier))
.WithRefKindKeyword(Token(info.RefKindSyntax));
}

public ParameterSyntax AsParameter(TypePositionInfo info)
{
return Parameter(Identifier(info.InstanceIdentifier))
.WithModifiers(TokenList(Token(info.RefKindSyntax)))
.WithType(info.ManagedType.AsTypeSyntax());
}

public ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context)
{
return Argument(IdentifierName(info.InstanceIdentifier))
.WithRefKindKeyword(Token(info.RefKindSyntax));
}

public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeContext context)
{
return Array.Empty<StatementSyntax>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,19 @@ internal interface IMarshallingGenerator
TypeSyntax AsNativeType(TypePositionInfo info);

/// <summary>
/// Get the <paramref name="info"/> as an argument to be passed to the P/Invoke
/// Get the <paramref name="info"/> as a parameter of the P/Invoke declaration
/// </summary>
/// <param name="info">Object to marshal</param>
/// <returns>Argument syntax for <paramref name="info"/></returns>
ArgumentSyntax AsArgument(TypePositionInfo info);
/// <returns>Parameter syntax for <paramref name="info"/></returns>
ParameterSyntax AsParameter(TypePositionInfo info);

/// <summary>
/// Get the <paramref name="info"/> as a parameter of the P/Invoke declaration
/// Get the <paramref name="info"/> as an argument to be passed to the P/Invoke
/// </summary>
/// <param name="info">Object to marshal</param>
/// <returns>Parameter syntax for <paramref name="info"/></returns>
ParameterSyntax AsParameter(TypePositionInfo info);
/// <param name="context">Code generation context</param>
/// <returns>Argument syntax for <paramref name="info"/></returns>
ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context);

/// <summary>
/// Generate code for marshalling
Expand Down Expand Up @@ -58,6 +59,12 @@ public static bool TryCreate(TypePositionInfo info, out IMarshallingGenerator ge
generator = MarshallingGenerators.Forwarder;
return true;
#else
if (info.IsNativeReturnPosition && !info.IsManagedReturnPosition)
{
// [TODO] Use marshaller for native HRESULT return / exception throwing
// Debug.Assert(info.ManagedType.SpecialType == SpecialType.System_Int32)
}

switch (info.ManagedType.SpecialType)
{
case SpecialType.System_SByte:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,31 @@ public TypeSyntax AsNativeType(TypePositionInfo info)
return info.NativeType.AsTypeSyntax();
}

public ArgumentSyntax AsArgument(TypePositionInfo info)
public ParameterSyntax AsParameter(TypePositionInfo info)
{
var type = info.IsByRef
? PointerType(info.NativeType.AsTypeSyntax())
: info.NativeType.AsTypeSyntax();
return Parameter(Identifier(info.InstanceIdentifier))
.WithType(type);
}

public ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context)
{
if (info.IsByRef)
{
string identifier = StubCodeContext.ToNativeIdentifer(info.InstanceIdentifier);
return Argument(
PrefixUnaryExpression(
SyntaxKind.AddressOfExpression,
IdentifierName(identifier)));
IdentifierName(context.GetIdentifiers(info).native)));
}

return Argument(IdentifierName(info.InstanceIdentifier));
}

public ParameterSyntax AsParameter(TypePositionInfo info)
{
var type = info.IsByRef
? PointerType(info.NativeType.AsTypeSyntax())
: info.NativeType.AsTypeSyntax();
return Parameter(Identifier(info.InstanceIdentifier))
.WithType(type);
}

public IEnumerable<StatementSyntax> Generate(TypePositionInfo info, StubCodeContext context)
{
if (!info.IsByRef)
if (!info.IsByRef || info.IsManagedReturnPosition)
yield break;

(string managedIdentifier, string nativeIdentifier) = context.GetIdentifiers(info);
Expand Down
84 changes: 57 additions & 27 deletions DllImportGenerator/DllImportGenerator/StubCodeContext.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;

using Microsoft.CodeAnalysis;
Expand Down Expand Up @@ -56,15 +57,15 @@ public enum Stage
/// <summary>
/// Identifier for managed return value
/// </summary>
public string ReturnIdentifier => returnIdentifier;
public const string ReturnIdentifier = "__retVal";

/// <summary>
/// Identifier for native return value
/// </summary>
/// <remarks>Same as the managed identifier by default</remarks>
public string ReturnNativeIdentifier { get; private set; } = returnIdentifier;
public string ReturnNativeIdentifier { get; private set; } = ReturnIdentifier;

private const string returnIdentifier = "__retVal";
private const string InvokeReturnIdentifier = "__invokeRetVal";
private const string generatedNativeIdentifierSuffix = "_gen_native";

private StubCodeContext(Stage stage)
Expand Down Expand Up @@ -93,27 +94,40 @@ public string GenerateReturnNativeIdentifier()
/// <returns>Managed and native identifiers</returns>
public (string managed, string native) GetIdentifiers(TypePositionInfo info)
{
string managedIdentifier = info.IsReturnType
? ReturnIdentifier
: info.InstanceIdentifier;
string managedIdentifier;
string nativeIdentifier;
if (info.IsManagedReturnPosition && !info.IsNativeReturnPosition)
{
managedIdentifier = ReturnIdentifier;
nativeIdentifier = ReturnNativeIdentifier;
}
else if (!info.IsManagedReturnPosition && info.IsNativeReturnPosition)
{
managedIdentifier = InvokeReturnIdentifier;
nativeIdentifier = InvokeReturnIdentifier;
}
else
{
managedIdentifier = info.IsManagedReturnPosition
? ReturnIdentifier
: info.InstanceIdentifier;

string nativeIdentifier = info.IsReturnType
? ReturnNativeIdentifier
: ToNativeIdentifer(info.InstanceIdentifier);
nativeIdentifier = info.IsNativeReturnPosition
? ReturnNativeIdentifier
: $"__{info.InstanceIdentifier}{generatedNativeIdentifierSuffix}";
}

return (managedIdentifier, nativeIdentifier);
}

public static string ToNativeIdentifer(string managedIdentifier)
{
return $"__{managedIdentifier}{generatedNativeIdentifierSuffix}";
}

public static (BlockSyntax Code, MethodDeclarationSyntax DllImport) GenerateSyntax(
string dllImportName,
IMethodSymbol stubMethod,
IEnumerable<TypePositionInfo> paramsTypeInfo,
TypePositionInfo retTypeInfo)
{
Debug.Assert(retTypeInfo.IsNativeReturnPosition);

string dllImportName = stubMethod.Name + "__PInvoke__";
var paramMarshallers = paramsTypeInfo.Select(p => GetMarshalInfo(p)).ToList();
var retMarshaller = GetMarshalInfo(retTypeInfo);

Expand All @@ -123,7 +137,7 @@ public static (BlockSyntax Code, MethodDeclarationSyntax DllImport) GenerateSynt
foreach (var marshaller in paramMarshallers)
{
TypePositionInfo info = marshaller.TypeInfo;
if (info.RefKind != RefKind.Out)
if (info.RefKind != RefKind.Out || info.IsManagedReturnPosition)
continue;

// Assign out params to default
Expand All @@ -136,15 +150,31 @@ public static (BlockSyntax Code, MethodDeclarationSyntax DllImport) GenerateSynt
Token(SyntaxKind.DefaultKeyword)))));
}

bool returnsVoid = retTypeInfo.ManagedType.SpecialType == SpecialType.System_Void;
if (!returnsVoid)
bool invokeReturnsVoid = retTypeInfo.ManagedType.SpecialType == SpecialType.System_Void;
bool stubReturnsVoid = stubMethod.ReturnsVoid;

// Stub return is not the same as invoke return
if (!stubReturnsVoid && !retTypeInfo.IsManagedReturnPosition)
{
Debug.Assert(paramsTypeInfo.Any() && paramsTypeInfo.Last().IsManagedReturnPosition);

// Declare variable for stub return value
TypePositionInfo info = paramsTypeInfo.Last();
statements.Add(LocalDeclarationStatement(
VariableDeclaration(
info.ManagedType.AsTypeSyntax(),
SingletonSeparatedList(
VariableDeclarator(context.GetIdentifiers(info).managed)))));
}

if (!invokeReturnsVoid)
{
// Declare variable for return value
// Declare variable for invoke return value
statements.Add(LocalDeclarationStatement(
VariableDeclaration(
retMarshaller.TypeInfo.ManagedType.AsTypeSyntax(),
retTypeInfo.ManagedType.AsTypeSyntax(),
SingletonSeparatedList(
VariableDeclarator(context.ReturnIdentifier)))));
VariableDeclarator(context.GetIdentifiers(retTypeInfo).managed)))));
}

var stages = new Stage[]
Expand All @@ -164,7 +194,7 @@ public static (BlockSyntax Code, MethodDeclarationSyntax DllImport) GenerateSynt
int initialCount = statements.Count;
context.CurrentStage = stage;

if (!returnsVoid && (stage == Stage.Setup || stage == Stage.Unmarshal))
if (!invokeReturnsVoid && (stage == Stage.Setup || stage == Stage.Unmarshal))
{
// Handle setup and unmarshalling for return
var retStatements = retMarshaller.Generator.Generate(retMarshaller.TypeInfo, context);
Expand All @@ -177,7 +207,7 @@ public static (BlockSyntax Code, MethodDeclarationSyntax DllImport) GenerateSynt
if (stage == Stage.Invoke)
{
// Get arguments for invocation
ArgumentSyntax argSyntax = marshaller.Generator.AsArgument(marshaller.TypeInfo);
ArgumentSyntax argSyntax = marshaller.Generator.AsArgument(marshaller.TypeInfo, context);
invoke = invoke.AddArgumentListArguments(argSyntax);
}
else
Expand Down Expand Up @@ -206,7 +236,7 @@ public static (BlockSyntax Code, MethodDeclarationSyntax DllImport) GenerateSynt
StatementSyntax invokeStatement;

// Assign to return value if necessary
if (returnsVoid)
if (invokeReturnsVoid)
{
invokeStatement = ExpressionStatement(invoke);
}
Expand All @@ -215,7 +245,7 @@ public static (BlockSyntax Code, MethodDeclarationSyntax DllImport) GenerateSynt
invokeStatement = ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName(context.ReturnNativeIdentifier),
IdentifierName(context.GetIdentifiers(retMarshaller.TypeInfo).native),
invoke));
}

Expand Down Expand Up @@ -247,8 +277,8 @@ public static (BlockSyntax Code, MethodDeclarationSyntax DllImport) GenerateSynt
}

// Return
if (!returnsVoid)
statements.Add(ReturnStatement(IdentifierName(context.ReturnIdentifier)));
if (!stubReturnsVoid)
statements.Add(ReturnStatement(IdentifierName(ReturnIdentifier)));

// Wrap all statements in an unsafe block
var codeBlock = Block(UnsafeStatement(Block(statements)));
Expand Down
Loading