From c7a449fe1c63bcb1a444345e0da7ded499202f70 Mon Sep 17 00:00:00 2001 From: Wojtek Figat Date: Thu, 15 Feb 2024 18:28:51 +0100 Subject: [PATCH] Fix marshaling custom type array to C# with `MarshalAs` used --- .../Bindings/BindingsGenerator.CSharp.cs | 72 ++++++++++++++----- .../Bindings/BindingsGenerator.Cpp.cs | 43 ++++++++--- 2 files changed, 88 insertions(+), 27 deletions(-) diff --git a/Source/Tools/Flax.Build/Bindings/BindingsGenerator.CSharp.cs b/Source/Tools/Flax.Build/Bindings/BindingsGenerator.CSharp.cs index dc9f74369..e1d4df083 100644 --- a/Source/Tools/Flax.Build/Bindings/BindingsGenerator.CSharp.cs +++ b/Source/Tools/Flax.Build/Bindings/BindingsGenerator.CSharp.cs @@ -270,7 +270,7 @@ namespace Flax.Build.Bindings return value; } - private static string GenerateCSharpNativeToManaged(BuildData buildData, TypeInfo typeInfo, ApiTypeInfo caller) + private static string GenerateCSharpNativeToManaged(BuildData buildData, TypeInfo typeInfo, ApiTypeInfo caller, bool marshalling = false) { string result; if (typeInfo?.Type == null) @@ -280,7 +280,7 @@ namespace Flax.Build.Bindings if (typeInfo.IsArray) { typeInfo.IsArray = false; - result = GenerateCSharpNativeToManaged(buildData, typeInfo, caller); + result = GenerateCSharpNativeToManaged(buildData, typeInfo, caller, marshalling); typeInfo.IsArray = true; return result + "[]"; } @@ -307,7 +307,7 @@ namespace Flax.Build.Bindings // Object reference property if (typeInfo.IsObjectRef) - return GenerateCSharpNativeToManaged(buildData, typeInfo.GenericArgs[0], caller); + return GenerateCSharpNativeToManaged(buildData, typeInfo.GenericArgs[0], caller, marshalling); if (typeInfo.Type == "SoftTypeReference" || typeInfo.Type == "SoftObjectReference") return typeInfo.Type; @@ -317,15 +317,25 @@ namespace Flax.Build.Bindings #else if ((typeInfo.Type == "Array" || typeInfo.Type == "Span" || typeInfo.Type == "DataContainer") && typeInfo.GenericArgs != null) #endif - return GenerateCSharpNativeToManaged(buildData, typeInfo.GenericArgs[0], caller) + "[]"; + { + var arrayTypeInfo = typeInfo.GenericArgs[0]; + if (marshalling) + { + // Convert array that uses different type for marshalling + var arrayApiType = FindApiTypeInfo(buildData, arrayTypeInfo, caller); + if (arrayApiType != null && arrayApiType.MarshalAs != null) + arrayTypeInfo = arrayApiType.MarshalAs; + } + return GenerateCSharpNativeToManaged(buildData, arrayTypeInfo, caller) + "[]"; + } // Dictionary if (typeInfo.Type == "Dictionary" && typeInfo.GenericArgs != null) - return string.Format("System.Collections.Generic.Dictionary<{0}, {1}>", GenerateCSharpNativeToManaged(buildData, typeInfo.GenericArgs[0], caller), GenerateCSharpNativeToManaged(buildData, typeInfo.GenericArgs[1], caller)); + return string.Format("System.Collections.Generic.Dictionary<{0}, {1}>", GenerateCSharpNativeToManaged(buildData, typeInfo.GenericArgs[0], caller, marshalling), GenerateCSharpNativeToManaged(buildData, typeInfo.GenericArgs[1], caller, marshalling)); // HashSet if (typeInfo.Type == "HashSet" && typeInfo.GenericArgs != null) - return string.Format("System.Collections.Generic.HashSet<{0}>", GenerateCSharpNativeToManaged(buildData, typeInfo.GenericArgs[0], caller)); + return string.Format("System.Collections.Generic.HashSet<{0}>", GenerateCSharpNativeToManaged(buildData, typeInfo.GenericArgs[0], caller, marshalling)); // BitArray if (typeInfo.Type == "BitArray" && typeInfo.GenericArgs != null) @@ -348,16 +358,16 @@ namespace Flax.Build.Bindings // TODO: generate delegates globally in the module namespace to share more code (smaller binary size) var key = string.Empty; for (int i = 0; i < typeInfo.GenericArgs.Count; i++) - key += GenerateCSharpNativeToManaged(buildData, typeInfo.GenericArgs[i], caller); + key += GenerateCSharpNativeToManaged(buildData, typeInfo.GenericArgs[i], caller, marshalling); if (!CSharpAdditionalCodeCache.TryGetValue(key, out var delegateName)) { delegateName = "Delegate" + CSharpAdditionalCodeCache.Count; - var signature = $"public delegate {GenerateCSharpNativeToManaged(buildData, typeInfo.GenericArgs[0], caller)} {delegateName}("; + var signature = $"public delegate {GenerateCSharpNativeToManaged(buildData, typeInfo.GenericArgs[0], caller, marshalling)} {delegateName}("; for (int i = 1; i < typeInfo.GenericArgs.Count; i++) { if (i != 1) signature += ", "; - signature += GenerateCSharpNativeToManaged(buildData, typeInfo.GenericArgs[i], caller); + signature += GenerateCSharpNativeToManaged(buildData, typeInfo.GenericArgs[i], caller, marshalling); signature += $" arg{(i - 1)}"; } signature += ");"; @@ -390,11 +400,14 @@ namespace Flax.Build.Bindings { typeName += '<'; foreach (var arg in typeInfo.GenericArgs) - typeName += GenerateCSharpNativeToManaged(buildData, arg, caller); + typeName += GenerateCSharpNativeToManaged(buildData, arg, caller, marshalling); typeName += '>'; } if (apiType != null) { + if (marshalling && apiType.MarshalAs != null) + return GenerateCSharpNativeToManaged(buildData, apiType.MarshalAs, caller); + // Add reference to the namespace CSharpUsedNamespaces.Add(apiType.Namespace); var apiTypeParent = apiType.Parent; @@ -419,11 +432,11 @@ namespace Flax.Build.Bindings return typeName; } - private static string GenerateCSharpManagedToNativeType(BuildData buildData, TypeInfo typeInfo, ApiTypeInfo caller) + private static string GenerateCSharpManagedToNativeType(BuildData buildData, TypeInfo typeInfo, ApiTypeInfo caller, bool marshalling = false) { // Fixed-size array if (typeInfo.IsArray) - return GenerateCSharpNativeToManaged(buildData, typeInfo, caller); + return GenerateCSharpNativeToManaged(buildData, typeInfo, caller, marshalling); // Find API type info var apiType = FindApiTypeInfo(buildData, typeInfo, caller); @@ -439,7 +452,7 @@ namespace Flax.Build.Bindings } if (apiType.MarshalAs != null) - return GenerateCSharpManagedToNativeType(buildData, apiType.MarshalAs, caller); + return GenerateCSharpManagedToNativeType(buildData, apiType.MarshalAs, caller, marshalling); if (apiType.IsScriptingObject || apiType.IsInterface) return "IntPtr"; } @@ -452,7 +465,7 @@ namespace Flax.Build.Bindings if (typeInfo.Type == "Function" && typeInfo.GenericArgs != null) return "IntPtr"; - return GenerateCSharpNativeToManaged(buildData, typeInfo, caller); + return GenerateCSharpNativeToManaged(buildData, typeInfo, caller, marshalling); } private static string GenerateCSharpManagedToNativeConverter(BuildData buildData, TypeInfo typeInfo, ApiTypeInfo caller) @@ -485,6 +498,18 @@ namespace Flax.Build.Bindings case "Function": // delegate return "NativeInterop.GetFunctionPointerForDelegate({0})"; + case "Array": + case "Span": + case "DataContainer": + if (typeInfo.GenericArgs != null) + { + // Convert array that uses different type for marshalling + var arrayTypeInfo = typeInfo.GenericArgs[0]; + var arrayApiType = FindApiTypeInfo(buildData, arrayTypeInfo, caller); + if (arrayApiType != null && arrayApiType.MarshalAs != null) + return $"{{0}}.ConvertArray(x => ({GenerateCSharpNativeToManaged(buildData, arrayApiType.MarshalAs, caller)})x)"; + } + return string.Empty; default: var apiType = FindApiTypeInfo(buildData, typeInfo, caller); if (apiType != null) @@ -531,9 +556,9 @@ namespace Flax.Build.Bindings { var apiType = FindApiTypeInfo(buildData, functionInfo.ReturnType, caller); if (apiType != null && apiType.MarshalAs != null) - returnValueType = GenerateCSharpNativeToManaged(buildData, apiType.MarshalAs, caller); + returnValueType = GenerateCSharpNativeToManaged(buildData, apiType.MarshalAs, caller, true); else - returnValueType = GenerateCSharpNativeToManaged(buildData, functionInfo.ReturnType, caller); + returnValueType = GenerateCSharpNativeToManaged(buildData, functionInfo.ReturnType, caller, true); } #if USE_NETCORE @@ -594,7 +619,7 @@ namespace Flax.Build.Bindings contents.Append(", "); separator = true; - var nativeType = GenerateCSharpManagedToNativeType(buildData, parameterInfo.Type, caller); + var nativeType = GenerateCSharpManagedToNativeType(buildData, parameterInfo.Type, caller, true); #if USE_NETCORE string parameterMarshalType = ""; if (nativeType == "System.Type") @@ -643,7 +668,7 @@ namespace Flax.Build.Bindings contents.Append(", "); separator = true; - var nativeType = GenerateCSharpManagedToNativeType(buildData, parameterInfo.Type, caller); + var nativeType = GenerateCSharpManagedToNativeType(buildData, parameterInfo.Type, caller, true); #if USE_NETCORE string parameterMarshalType = ""; if (parameterInfo.IsOut && parameterInfo.DefaultValue == "var __resultAsRef") @@ -756,7 +781,16 @@ namespace Flax.Build.Bindings } } - contents.Append(");"); + contents.Append(')'); + if ((functionInfo.ReturnType.Type == "Array" || functionInfo.ReturnType.Type == "Span" || functionInfo.ReturnType.Type == "DataContainer") && functionInfo.ReturnType.GenericArgs != null) + { + // Convert array that uses different type for marshalling + var arrayTypeInfo = functionInfo.ReturnType.GenericArgs[0]; + var arrayApiType = FindApiTypeInfo(buildData, arrayTypeInfo, caller); + if (arrayApiType != null && arrayApiType.MarshalAs != null) + contents.Append($".ConvertArray(x => ({GenerateCSharpNativeToManaged(buildData, arrayTypeInfo, caller)})x)"); + } + contents.Append(';'); // Return result if (functionInfo.Glue.UseReferenceForResult) diff --git a/Source/Tools/Flax.Build/Bindings/BindingsGenerator.Cpp.cs b/Source/Tools/Flax.Build/Bindings/BindingsGenerator.Cpp.cs index 8e1836b2f..35d1d32fa 100644 --- a/Source/Tools/Flax.Build/Bindings/BindingsGenerator.Cpp.cs +++ b/Source/Tools/Flax.Build/Bindings/BindingsGenerator.Cpp.cs @@ -305,7 +305,7 @@ namespace Flax.Build.Bindings private static string GenerateCppGetMClass(BuildData buildData, TypeInfo typeInfo, ApiTypeInfo caller, FunctionInfo functionInfo) { // Optimal path for in-build types - var managedType = GenerateCSharpNativeToManaged(buildData, typeInfo, caller); + var managedType = GenerateCSharpNativeToManaged(buildData, typeInfo, caller, true); switch (managedType) { // In-built types (cached by the engine on startup) @@ -388,7 +388,7 @@ namespace Flax.Build.Bindings CppIncludeFiles.Add("Engine/Scripting/ManagedCLR/MClass.h"); // Optimal path for in-build types - var managedType = GenerateCSharpNativeToManaged(buildData, typeInfo, caller); + var managedType = GenerateCSharpNativeToManaged(buildData, typeInfo, caller, true); switch (managedType) { case "bool": @@ -519,16 +519,28 @@ namespace Flax.Build.Bindings // Array or DataContainer if ((typeInfo.Type == "Array" || typeInfo.Type == "Span" || typeInfo.Type == "DataContainer") && typeInfo.GenericArgs != null) { + var arrayTypeInfo = typeInfo.GenericArgs[0]; #if USE_NETCORE // Boolean arrays does not support custom marshalling for some unknown reason - if (typeInfo.GenericArgs[0].Type == "bool") + if (arrayTypeInfo.Type == "bool") { type = "bool*"; return "MUtils::ToBoolArray({0})"; } + var arrayApiType = FindApiTypeInfo(buildData, arrayTypeInfo, caller); #endif type = "MArray*"; - return "MUtils::ToArray({0}, " + GenerateCppGetMClass(buildData, typeInfo.GenericArgs[0], caller, functionInfo) + ")"; + if (arrayApiType != null && arrayApiType.MarshalAs != null) + { + // Convert array that uses different type for marshalling + if (arrayApiType != null && arrayApiType.MarshalAs != null) + arrayTypeInfo = arrayApiType.MarshalAs; // Convert array that uses different type for marshalling + var genericArgs = arrayApiType.MarshalAs.GetFullNameNative(buildData, caller); + if (typeInfo.GenericArgs.Count != 1) + genericArgs += ", " + typeInfo.GenericArgs[1]; + return "MUtils::ToArray(Array<" + genericArgs + ">({0}), " + GenerateCppGetMClass(buildData, arrayTypeInfo, caller, functionInfo) + ")"; + } + return "MUtils::ToArray({0}, " + GenerateCppGetMClass(buildData, arrayTypeInfo, caller, functionInfo) + ")"; } // Span @@ -719,11 +731,26 @@ namespace Flax.Build.Bindings // Array if (typeInfo.Type == "Array" && typeInfo.GenericArgs != null) { - var T = typeInfo.GenericArgs[0].GetFullNameNative(buildData, caller); - type = "MArray*"; + var arrayTypeInfo = typeInfo.GenericArgs[0]; + var arrayApiType = FindApiTypeInfo(buildData, arrayTypeInfo, caller); + if (arrayApiType != null && arrayApiType.MarshalAs != null) + arrayTypeInfo = arrayApiType.MarshalAs; + var genericArgs = arrayTypeInfo.GetFullNameNative(buildData, caller); if (typeInfo.GenericArgs.Count != 1) - return "MUtils::ToArray<" + T + ", " + typeInfo.GenericArgs[1] + ">({0})"; - return "MUtils::ToArray<" + T + ">({0})"; + genericArgs += ", " + typeInfo.GenericArgs[1]; + + type = "MArray*"; + var result = "MUtils::ToArray<" + genericArgs + ">({0})"; + + if (arrayApiType != null && arrayApiType.MarshalAs != null) + { + // Convert array that uses different type for marshalling + genericArgs = typeInfo.GenericArgs[0].GetFullNameNative(buildData, caller); + if (typeInfo.GenericArgs.Count != 1) + genericArgs += ", " + typeInfo.GenericArgs[1]; + result = $"Array<{genericArgs}>({result})"; + } + return result; } // Span or DataContainer