Add support for interfaces in scripting API (cross language support C++/C#/VS)

This commit is contained in:
Wojtek Figat
2021-10-04 12:22:28 +02:00
parent 147e5ada46
commit c3c0a4ef0d
16 changed files with 1039 additions and 355 deletions

View File

@@ -233,7 +233,7 @@ namespace Flax.Build.Bindings
{
CSharpUsedNamespaces.Add(apiType.Namespace);
if (apiType.IsScriptingObject)
if (apiType.IsScriptingObject || apiType.IsInterface)
return typeInfo.Type.Replace("::", ".");
if (typeInfo.IsPtr && apiType.IsPod)
return typeInfo.Type.Replace("::", ".") + '*';
@@ -256,7 +256,7 @@ namespace Flax.Build.Bindings
var apiType = FindApiTypeInfo(buildData, typeInfo, caller);
if (apiType != null)
{
if (apiType.IsScriptingObject)
if (apiType.IsScriptingObject || apiType.IsInterface)
return "IntPtr";
}
@@ -297,22 +297,26 @@ namespace Flax.Build.Bindings
case "PersistentScriptingObject":
case "ManagedScriptingObject":
// object
return "FlaxEngine.Object.GetUnmanagedPtr";
return "FlaxEngine.Object.GetUnmanagedPtr({0})";
case "Function":
// delegate
return "Marshal.GetFunctionPointerForDelegate";
return "Marshal.GetFunctionPointerForDelegate({0})";
default:
var apiType = FindApiTypeInfo(buildData, typeInfo, caller);
if (apiType != null)
{
// Scripting Object
if (apiType.IsScriptingObject)
return "FlaxEngine.Object.GetUnmanagedPtr";
return "FlaxEngine.Object.GetUnmanagedPtr({0})";
// interface
if (apiType.IsInterface)
return string.Format("FlaxEngine.Object.GetUnmanagedInterface({{0}}, typeof({0}))", apiType.FullNameManaged);
}
// ScriptingObjectReference or AssetReference or WeakAssetReference or SoftObjectReference
if ((typeInfo.Type == "ScriptingObjectReference" || typeInfo.Type == "AssetReference" || typeInfo.Type == "WeakAssetReference" || typeInfo.Type == "SoftObjectReference") && typeInfo.GenericArgs != null)
return "FlaxEngine.Object.GetUnmanagedPtr";
return "FlaxEngine.Object.GetUnmanagedPtr({0})";
// Default
return string.Empty;
@@ -358,8 +362,10 @@ namespace Flax.Build.Bindings
contents.Append(parameterInfo.Name);
}
foreach (var parameterInfo in functionInfo.Glue.CustomParameters)
var customParametersCount = functionInfo.Glue.CustomParameters?.Count ?? 0;
for (var i = 0; i < customParametersCount; i++)
{
var parameterInfo = functionInfo.Glue.CustomParameters[i];
if (separator)
contents.Append(", ");
separator = true;
@@ -424,15 +430,14 @@ namespace Flax.Build.Bindings
throw new Exception($"Cannot use Ref meta on parameter {parameterInfo} in function {functionInfo.Name} in {caller}.");
// Convert value
contents.Append(convertFunc);
contents.Append('(');
contents.Append(isSetter ? "value" : parameterInfo.Name);
contents.Append(')');
contents.Append(string.Format(convertFunc, isSetter ? "value" : parameterInfo.Name));
}
}
foreach (var parameterInfo in functionInfo.Glue.CustomParameters)
var customParametersCount = functionInfo.Glue.CustomParameters?.Count ?? 0;
for (var i = 0; i < customParametersCount; i++)
{
var parameterInfo = functionInfo.Glue.CustomParameters[i];
if (separator)
contents.Append(", ");
separator = true;
@@ -451,10 +456,7 @@ namespace Flax.Build.Bindings
else
{
// Convert value
contents.Append(convertFunc);
contents.Append('(');
contents.Append(parameterInfo.DefaultValue);
contents.Append(')');
contents.Append(string.Format(convertFunc, parameterInfo.DefaultValue));
}
}
@@ -571,8 +573,24 @@ namespace Flax.Build.Bindings
else if (classInfo.IsAbstract)
contents.Append("abstract ");
contents.Append("unsafe partial class ").Append(classInfo.Name);
if (classInfo.BaseType != null && !classInfo.IsBaseTypeHidden)
var hasBase = classInfo.BaseType != null && !classInfo.IsBaseTypeHidden;
if (hasBase)
contents.Append(" : ").Append(GenerateCSharpNativeToManaged(buildData, new TypeInfo { Type = classInfo.BaseType.Name }, classInfo));
var hasInterface = false;
if (classInfo.Interfaces != null)
{
foreach (var interfaceInfo in classInfo.Interfaces)
{
if (interfaceInfo.Access != AccessLevel.Public)
continue;
if (hasInterface || hasBase)
contents.Append(", ");
else
contents.Append(" : ");
hasInterface = true;
contents.Append(interfaceInfo.FullNameManaged);
}
}
contents.AppendLine();
contents.Append(indent + "{");
indent += " ";
@@ -902,6 +920,71 @@ namespace Flax.Build.Bindings
GenerateCSharpWrapperFunction(buildData, contents, indent, classInfo, functionInfo);
}
// Interface implementation
if (hasInterface)
{
foreach (var interfaceInfo in classInfo.Interfaces)
{
if (interfaceInfo.Access != AccessLevel.Public)
continue;
foreach (var functionInfo in interfaceInfo.Functions)
{
if (!classInfo.IsScriptingObject)
throw new Exception($"Class {classInfo.Name} cannot implement interface {interfaceInfo.Name} because it requires ScriptingObject as a base class.");
contents.AppendLine();
if (functionInfo.Comment.Length != 0)
contents.Append(indent).AppendLine("/// <inheritdoc />");
GenerateCSharpAttributes(buildData, contents, indent, classInfo, functionInfo.Attributes, null, false, useUnmanaged);
contents.Append(indent);
if (functionInfo.Access == AccessLevel.Public)
contents.Append("public ");
else if (functionInfo.Access == AccessLevel.Protected)
contents.Append("protected ");
else if (functionInfo.Access == AccessLevel.Private)
contents.Append("private ");
if (functionInfo.IsVirtual && !classInfo.IsSealed)
contents.Append("virtual ");
var returnValueType = GenerateCSharpNativeToManaged(buildData, functionInfo.ReturnType, classInfo);
contents.Append(returnValueType).Append(' ').Append(functionInfo.Name).Append('(');
for (var i = 0; i < functionInfo.Parameters.Count; i++)
{
var parameterInfo = functionInfo.Parameters[i];
if (i != 0)
contents.Append(", ");
if (!string.IsNullOrEmpty(parameterInfo.Attributes))
contents.Append('[').Append(parameterInfo.Attributes).Append(']').Append(' ');
var managedType = GenerateCSharpNativeToManaged(buildData, parameterInfo.Type, classInfo);
if (parameterInfo.IsOut)
contents.Append("out ");
else if (parameterInfo.IsRef)
contents.Append("ref ");
contents.Append(managedType);
contents.Append(' ');
contents.Append(parameterInfo.Name);
var defaultValue = GenerateCSharpDefaultValueNativeToManaged(buildData, parameterInfo.DefaultValue, classInfo);
if (!string.IsNullOrEmpty(defaultValue))
contents.Append(" = ").Append(defaultValue);
}
contents.Append(')').AppendLine().AppendLine(indent + "{");
indent += " ";
contents.Append(indent);
GenerateCSharpWrapperFunctionCall(buildData, contents, classInfo, functionInfo);
indent = indent.Substring(0, indent.Length - 4);
contents.AppendLine();
contents.AppendLine(indent + "}");
GenerateCSharpWrapperFunction(buildData, contents, indent, classInfo, functionInfo);
}
}
}
// Nested types
foreach (var apiTypeInfo in classInfo.Children)
{
@@ -1109,6 +1192,81 @@ namespace Flax.Build.Bindings
}
}
private static void GenerateCSharpInterface(BuildData buildData, StringBuilder contents, string indent, InterfaceInfo interfaceInfo)
{
// Begin
contents.AppendLine();
if (!string.IsNullOrEmpty(interfaceInfo.Namespace))
{
contents.AppendFormat("namespace ");
contents.AppendLine(interfaceInfo.Namespace);
contents.AppendLine("{");
indent += " ";
}
GenerateCSharpComment(contents, indent, interfaceInfo.Comment);
GenerateCSharpAttributes(buildData, contents, indent, interfaceInfo, true);
contents.Append(indent);
if (interfaceInfo.Access == AccessLevel.Public)
contents.Append("public ");
else if (interfaceInfo.Access == AccessLevel.Protected)
contents.Append("protected ");
else if (interfaceInfo.Access == AccessLevel.Private)
contents.Append("private ");
contents.Append("unsafe partial interface ").Append(interfaceInfo.Name);
contents.AppendLine();
contents.Append(indent + "{");
indent += " ";
// Functions
foreach (var functionInfo in interfaceInfo.Functions)
{
if (functionInfo.IsStatic)
throw new Exception($"Not supported {"static"} function {functionInfo.Name} inside interface {interfaceInfo.Name}.");
if (functionInfo.NoProxy)
throw new Exception($"Not supported {"NoProxy"} function {functionInfo.Name} inside interface {interfaceInfo.Name}.");
if (!functionInfo.IsVirtual)
throw new Exception($"Not supported {"non-virtual"} function {functionInfo.Name} inside interface {interfaceInfo.Name}.");
if (functionInfo.Access != AccessLevel.Public)
throw new Exception($"Not supported {"non-public"} function {functionInfo.Name} inside interface {interfaceInfo.Name}.");
contents.AppendLine();
GenerateCSharpComment(contents, indent, functionInfo.Comment);
GenerateCSharpAttributes(buildData, contents, indent, interfaceInfo, functionInfo, true);
var returnValueType = GenerateCSharpNativeToManaged(buildData, functionInfo.ReturnType, interfaceInfo);
contents.Append(indent).Append(returnValueType).Append(' ').Append(functionInfo.Name).Append('(');
for (var i = 0; i < functionInfo.Parameters.Count; i++)
{
var parameterInfo = functionInfo.Parameters[i];
if (i != 0)
contents.Append(", ");
if (!string.IsNullOrEmpty(parameterInfo.Attributes))
contents.Append('[').Append(parameterInfo.Attributes).Append(']').Append(' ');
var managedType = GenerateCSharpNativeToManaged(buildData, parameterInfo.Type, interfaceInfo);
if (parameterInfo.IsOut)
contents.Append("out ");
else if (parameterInfo.IsRef)
contents.Append("ref ");
contents.Append(managedType);
contents.Append(' ');
contents.Append(parameterInfo.Name);
var defaultValue = GenerateCSharpDefaultValueNativeToManaged(buildData, parameterInfo.DefaultValue, interfaceInfo);
if (!string.IsNullOrEmpty(defaultValue))
contents.Append(" = ").Append(defaultValue);
}
contents.Append(");").AppendLine();
}
// End
indent = indent.Substring(0, indent.Length - 4);
contents.AppendLine(indent + "}");
if (!string.IsNullOrEmpty(interfaceInfo.Namespace))
contents.AppendLine("}");
}
private static bool GenerateCSharpType(BuildData buildData, StringBuilder contents, string indent, object type)
{
if (type is ApiTypeInfo apiTypeInfo && apiTypeInfo.IsInBuild)
@@ -1122,6 +1280,8 @@ namespace Flax.Build.Bindings
GenerateCSharpStructure(buildData, contents, indent, structureInfo);
else if (type is EnumInfo enumInfo)
GenerateCSharpEnum(buildData, contents, indent, enumInfo);
else if (type is InterfaceInfo interfaceInfo)
GenerateCSharpInterface(buildData, contents, indent, interfaceInfo);
else
return false;
}

View File

@@ -19,7 +19,7 @@ namespace Flax.Build.Bindings
partial class BindingsGenerator
{
private static readonly Dictionary<string, Type> TypeCache = new Dictionary<string, Type>();
private const int CacheVersion = 8;
private const int CacheVersion = 9;
internal static void Write(BinaryWriter writer, string e)
{

View File

@@ -35,9 +35,9 @@ namespace Flax.Build.Bindings
public static event Action<BuildData, IGrouping<string, Module>, StringBuilder> GenerateCppBinaryModuleHeader;
public static event Action<BuildData, IGrouping<string, Module>, StringBuilder> GenerateCppBinaryModuleSource;
public static event Action<BuildData, ModuleInfo, StringBuilder> GenerateCppModuleSource;
public static event Action<BuildData, ClassInfo, StringBuilder> GenerateCppClassInternals;
public static event Action<BuildData, ClassInfo, StringBuilder> GenerateCppClassInitRuntime;
public static event Action<BuildData, ClassInfo, FunctionInfo, int, int, StringBuilder> GenerateCppScriptWrapperFunction;
public static event Action<BuildData, VirtualClassInfo, StringBuilder> GenerateCppClassInternals;
public static event Action<BuildData, VirtualClassInfo, StringBuilder> GenerateCppClassInitRuntime;
public static event Action<BuildData, VirtualClassInfo, FunctionInfo, int, int, StringBuilder> GenerateCppScriptWrapperFunction;
private static readonly List<string> CppInBuildVariantStructures = new List<string>
{
@@ -458,6 +458,13 @@ namespace Flax.Build.Bindings
return "ScriptingObject::ToManaged((ScriptingObject*){0})";
}
// interface
if (apiType.IsInterface)
{
type = "MonoObject*";
return "ScriptingObject::ToManaged(ScriptingObject::FromInterface({0}, " + apiType.NativeName + "::TypeInitializer))";
}
// Non-POD structure passed as value (eg. it contains string or array inside)
if (apiType.IsStruct && !apiType.IsPod)
{
@@ -977,7 +984,24 @@ namespace Flax.Build.Bindings
contents.AppendLine();
}
private static void GenerateCppManagedWrapperFunction(BuildData buildData, StringBuilder contents, ClassInfo classInfo, FunctionInfo functionInfo, int scriptVTableSize, int scriptVTableIndex)
public static void GenerateCppReturn(BuildData buildData, StringBuilder contents, string indent, TypeInfo type)
{
contents.Append(indent);
if (type.IsVoid)
{
contents.AppendLine("return;");
return;
}
if (type.IsPtr)
{
contents.AppendLine("return nullptr;");
return;
}
contents.AppendLine($"{type} __return {{}};");
contents.Append(indent).AppendLine("return __return;");
}
private static void GenerateCppManagedWrapperFunction(BuildData buildData, StringBuilder contents, VirtualClassInfo classInfo, FunctionInfo functionInfo, int scriptVTableSize, int scriptVTableIndex)
{
contents.AppendFormat(" {0} {1}_ManagedWrapper(", functionInfo.ReturnType, functionInfo.UniqueName);
var separator = false;
@@ -998,7 +1022,23 @@ namespace Flax.Build.Bindings
contents.Append(')');
contents.AppendLine();
contents.AppendLine(" {");
contents.AppendLine($" auto object = ({classInfo.NativeName}*)this;");
string scriptVTableOffset;
if (classInfo.IsInterface)
{
contents.AppendLine($" auto object = ScriptingObject::FromInterface(this, {classInfo.NativeName}::TypeInitializer);");
contents.AppendLine(" if (object == nullptr)");
contents.AppendLine(" {");
contents.AppendLine($" LOG(Error, \"Failed to cast interface {{0}} to scripting object\", TEXT(\"{classInfo.Name}\"));");
GenerateCppReturn(buildData, contents, " ", functionInfo.ReturnType);
contents.AppendLine(" }");
contents.AppendLine($" const int32 scriptVTableOffset = {scriptVTableIndex} + object->GetType().GetInterface({classInfo.NativeName}::TypeInitializer)->ScriptVTableOffset;");
scriptVTableOffset = "scriptVTableOffset";
}
else
{
contents.AppendLine($" auto object = ({classInfo.NativeName}*)this;");
scriptVTableOffset = scriptVTableIndex.ToString();
}
contents.AppendLine(" static THREADLOCAL void* WrapperCallInstance = nullptr;");
contents.AppendLine(" ScriptingTypeHandle managedTypeHandle = object->GetTypeHandle();");
@@ -1013,7 +1053,7 @@ namespace Flax.Build.Bindings
contents.AppendLine(" {");
contents.AppendLine(" // Prevent stack overflow by calling native base method");
contents.AppendLine(" const auto scriptVTableBase = managedTypePtr->Script.ScriptVTableBase;");
contents.Append($" return (this->**({functionInfo.UniqueName}_Internal_Signature*)&scriptVTableBase[{scriptVTableIndex} + 2])(");
contents.Append($" return (this->**({functionInfo.UniqueName}_Internal_Signature*)&scriptVTableBase[{scriptVTableOffset} + 2])(");
separator = false;
for (var i = 0; i < functionInfo.Parameters.Count; i++)
{
@@ -1026,8 +1066,8 @@ namespace Flax.Build.Bindings
contents.AppendLine(");");
contents.AppendLine(" }");
contents.AppendLine(" auto scriptVTable = (MMethod**)managedTypePtr->Script.ScriptVTable;");
contents.AppendLine($" ASSERT(scriptVTable && scriptVTable[{scriptVTableIndex}]);");
contents.AppendLine($" auto method = scriptVTable[{scriptVTableIndex}];");
contents.AppendLine($" ASSERT(scriptVTable && scriptVTable[{scriptVTableOffset}]);");
contents.AppendLine($" auto method = scriptVTable[{scriptVTableOffset}];");
contents.AppendLine($" PROFILE_CPU_NAMED(\"{classInfo.FullNameManaged}::{functionInfo.Name}\");");
contents.AppendLine(" MonoObject* exception = nullptr;");
@@ -1122,6 +1162,107 @@ namespace Flax.Build.Bindings
contents.AppendLine();
}
private static string GenerateCppScriptVTable(BuildData buildData, StringBuilder contents, VirtualClassInfo classInfo)
{
var scriptVTableSize = classInfo.GetScriptVTableSize(out var scriptVTableOffset);
if (scriptVTableSize == 0)
return "nullptr, nullptr";
var scriptVTableIndex = scriptVTableOffset;
foreach (var functionInfo in classInfo.Functions)
{
if (!functionInfo.IsVirtual)
continue;
GenerateCppManagedWrapperFunction(buildData, contents, classInfo, functionInfo, scriptVTableSize, scriptVTableIndex);
GenerateCppScriptWrapperFunction?.Invoke(buildData, classInfo, functionInfo, scriptVTableSize, scriptVTableIndex, contents);
scriptVTableIndex++;
}
CppIncludeFiles.Add("Engine/Scripting/ManagedCLR/MMethod.h");
CppIncludeFiles.Add("Engine/Scripting/ManagedCLR/MClass.h");
foreach (var functionInfo in classInfo.Functions)
{
if (!functionInfo.IsVirtual)
continue;
var thunkParams = string.Empty;
var separator = false;
for (var i = 0; i < functionInfo.Parameters.Count; i++)
{
var parameterInfo = functionInfo.Parameters[i];
if (separator)
thunkParams += ", ";
separator = true;
thunkParams += parameterInfo.Type;
}
var t = functionInfo.IsConst ? " const" : string.Empty;
contents.AppendLine($" typedef {functionInfo.ReturnType} ({classInfo.NativeName}::*{functionInfo.UniqueName}_Signature)({thunkParams}){t};");
contents.AppendLine($" typedef {functionInfo.ReturnType} ({classInfo.NativeName}Internal::*{functionInfo.UniqueName}_Internal_Signature)({thunkParams}){t};");
}
contents.AppendLine("");
contents.AppendLine(" static void SetupScriptVTable(MClass* mclass, void**& scriptVTable, void**& scriptVTableBase)");
contents.AppendLine(" {");
if (classInfo.IsInterface)
{
contents.AppendLine(" ASSERT(scriptVTable);");
}
else
{
contents.AppendLine(" if (!scriptVTable)");
contents.AppendLine(" {");
contents.AppendLine($" scriptVTable = (void**)Platform::Allocate(sizeof(void*) * {scriptVTableSize + 1}, 16);");
contents.AppendLine($" Platform::MemoryClear(scriptVTable, sizeof(void*) * {scriptVTableSize + 1});");
contents.AppendLine($" scriptVTableBase = (void**)Platform::Allocate(sizeof(void*) * {scriptVTableSize + 2}, 16);");
contents.AppendLine(" }");
}
scriptVTableIndex = scriptVTableOffset;
foreach (var functionInfo in classInfo.Functions)
{
if (!functionInfo.IsVirtual)
continue;
contents.AppendLine($" scriptVTable[{scriptVTableIndex++}] = mclass->GetMethod(\"{functionInfo.Name}\", {functionInfo.Parameters.Count});");
}
contents.AppendLine(" }");
contents.AppendLine("");
contents.AppendLine(" static void SetupScriptObjectVTable(void** scriptVTable, void** scriptVTableBase, void** vtable, int32 entriesCount, int32 wrapperIndex)");
contents.AppendLine(" {");
scriptVTableIndex = scriptVTableOffset;
foreach (var functionInfo in classInfo.Functions)
{
if (!functionInfo.IsVirtual)
continue;
contents.AppendLine($" if (scriptVTable[{scriptVTableIndex}])");
contents.AppendLine(" {");
contents.AppendLine($" {functionInfo.UniqueName}_Signature funcPtr = &{classInfo.NativeName}::{functionInfo.Name};");
contents.AppendLine(" const int32 vtableIndex = GetVTableIndex(vtable, entriesCount, *(void**)&funcPtr);");
contents.AppendLine(" if (vtableIndex > 0 && vtableIndex < entriesCount)");
contents.AppendLine(" {");
contents.AppendLine($" scriptVTableBase[{scriptVTableIndex} + 2] = vtable[vtableIndex];");
for (var i = 0; i < CppScriptObjectVirtualWrapperMethodsPostfixes.Count; i++)
{
contents.AppendLine(i == 0 ? " if (wrapperIndex == 0)" : $" else if (wrapperIndex == {i})");
contents.AppendLine(" {");
contents.AppendLine($" auto thunkPtr = &{classInfo.NativeName}Internal::{functionInfo.UniqueName}{CppScriptObjectVirtualWrapperMethodsPostfixes[i]};");
contents.AppendLine(" vtable[vtableIndex] = *(void**)&thunkPtr;");
contents.AppendLine(" }");
}
contents.AppendLine(" }");
contents.AppendLine(" else");
contents.AppendLine(" {");
contents.AppendLine($" LOG(Error, \"Failed to find the vtable entry for method {{0}} in class {{1}}\", TEXT(\"{functionInfo.Name}\"), TEXT(\"{classInfo.Name}\"));");
contents.AppendLine(" }");
contents.AppendLine(" }");
scriptVTableIndex++;
}
contents.AppendLine(" }");
contents.AppendLine("");
return $"&{classInfo.NativeName}Internal::SetupScriptVTable, &{classInfo.NativeName}Internal::SetupScriptObjectVTable";
}
private static string GenerateCppAutoSerializationDefineType(BuildData buildData, StringBuilder contents, ModuleInfo moduleInfo, ApiTypeInfo caller, TypeInfo memberType, MemberInfo member)
{
if (memberType.IsBitField)
@@ -1225,7 +1366,7 @@ namespace Flax.Build.Bindings
contents.Append('}').AppendLine();
}
private static string GenerateCppInterfaceInheritanceTable(BuildData buildData, StringBuilder contents, ModuleInfo moduleInfo, ClassStructInfo typeInfo, string typeNameNative)
private static string GenerateCppInterfaceInheritanceTable(BuildData buildData, StringBuilder contents, ModuleInfo moduleInfo, VirtualClassInfo typeInfo, string typeNameNative)
{
var interfacesPtr = "nullptr";
var interfaces = typeInfo.Interfaces;
@@ -1236,7 +1377,8 @@ namespace Flax.Build.Bindings
for (int i = 0; i < interfaces.Count; i++)
{
var interfaceInfo = interfaces[i];
contents.Append(" { &").Append(interfaceInfo.NativeName).Append("::TypeInitializer, (int16)VTABLE_OFFSET(").Append(typeInfo.NativeName).Append(", ").Append(interfaceInfo.NativeName).AppendLine(") },");
var scriptVTableOffset = typeInfo.GetScriptVTableOffset(interfaceInfo);
contents.AppendLine($" {{ &{interfaceInfo.NativeName}::TypeInitializer, (int16)VTABLE_OFFSET({typeInfo.NativeName}, {interfaceInfo.NativeName}), {scriptVTableOffset}, true }},");
}
contents.AppendLine(" { nullptr, 0 },");
contents.AppendLine("};");
@@ -1253,6 +1395,7 @@ namespace Flax.Build.Bindings
if (classInfo.Parent != null && !(classInfo.Parent is FileInfo))
classTypeNameInternal = classInfo.Parent.FullNameNative + '_' + classTypeNameInternal;
var useScripting = classInfo.IsStatic || classInfo.IsScriptingObject;
var hasInterface = classInfo.Interfaces != null && classInfo.Interfaces.Any(x => x.Access == AccessLevel.Public);
if (classInfo.IsAutoSerialization)
GenerateCppAutoSerialization(buildData, contents, moduleInfo, classInfo, classTypeNameNative);
@@ -1424,110 +1567,27 @@ namespace Flax.Build.Bindings
GenerateCppWrapperFunction(buildData, contents, classInfo, functionInfo);
}
GenerateCppClassInternals?.Invoke(buildData, classInfo, contents);
// Virtual methods overrides
var setupScriptVTable = "nullptr, nullptr";
if (!classInfo.IsSealed)
// Interface implementation
if (hasInterface)
{
var scriptVTableSize = classInfo.GetScriptVTableSize(buildData, out var scriptVTableOffset);
if (scriptVTableSize > 0)
foreach (var interfaceInfo in classInfo.Interfaces)
{
var scriptVTableIndex = scriptVTableOffset;
foreach (var functionInfo in classInfo.Functions)
if (interfaceInfo.Access != AccessLevel.Public)
continue;
foreach (var functionInfo in interfaceInfo.Functions)
{
if (functionInfo.IsVirtual)
{
GenerateCppManagedWrapperFunction(buildData, contents, classInfo, functionInfo, scriptVTableSize, scriptVTableIndex);
GenerateCppScriptWrapperFunction?.Invoke(buildData, classInfo, functionInfo, scriptVTableSize, scriptVTableIndex, contents);
scriptVTableIndex++;
}
}
if (scriptVTableOffset != scriptVTableSize)
{
CppIncludeFiles.Add("Engine/Scripting/ManagedCLR/MMethod.h");
CppIncludeFiles.Add("Engine/Scripting/ManagedCLR/MClass.h");
foreach (var functionInfo in classInfo.Functions)
{
if (!functionInfo.IsVirtual)
continue;
var thunkParams = string.Empty;
var separator = false;
for (var i = 0; i < functionInfo.Parameters.Count; i++)
{
var parameterInfo = functionInfo.Parameters[i];
if (separator)
thunkParams += ", ";
separator = true;
thunkParams += parameterInfo.Type;
}
var t = functionInfo.IsConst ? " const" : string.Empty;
contents.AppendLine($" typedef {functionInfo.ReturnType} ({classInfo.NativeName}::*{functionInfo.UniqueName}_Signature)({thunkParams}){t};");
contents.AppendLine($" typedef {functionInfo.ReturnType} ({classInfo.NativeName}Internal::*{functionInfo.UniqueName}_Internal_Signature)({thunkParams}){t};");
}
contents.AppendLine("");
contents.AppendLine(" static void SetupScriptVTable(MClass* mclass, void**& scriptVTable, void**& scriptVTableBase)");
contents.AppendLine(" {");
contents.AppendLine(" if (!scriptVTable)");
contents.AppendLine(" {");
contents.AppendLine($" scriptVTable = (void**)Platform::Allocate(sizeof(void*) * {scriptVTableSize + 1}, 16);");
contents.AppendLine($" Platform::MemoryClear(scriptVTable, sizeof(void*) * {scriptVTableSize + 1});");
contents.AppendLine($" scriptVTableBase = (void**)Platform::Allocate(sizeof(void*) * {scriptVTableSize + 2}, 16);");
contents.AppendLine(" }");
scriptVTableIndex = scriptVTableOffset;
foreach (var functionInfo in classInfo.Functions)
{
if (!functionInfo.IsVirtual)
continue;
contents.AppendLine($" scriptVTable[{scriptVTableIndex++}] = mclass->GetMethod(\"{functionInfo.Name}\", {functionInfo.Parameters.Count});");
}
contents.AppendLine(" }");
contents.AppendLine("");
contents.AppendLine(" static void SetupScriptObjectVTable(void** scriptVTable, void** scriptVTableBase, void** vtable, int32 entriesCount, int32 wrapperIndex)");
contents.AppendLine(" {");
scriptVTableIndex = scriptVTableOffset;
foreach (var functionInfo in classInfo.Functions)
{
if (!functionInfo.IsVirtual)
continue;
contents.AppendLine($" if (scriptVTable[{scriptVTableIndex}])");
contents.AppendLine(" {");
contents.AppendLine($" {functionInfo.UniqueName}_Signature funcPtr = &{classInfo.NativeName}::{functionInfo.Name};");
contents.AppendLine(" const int32 vtableIndex = GetVTableIndex(vtable, entriesCount, *(void**)&funcPtr);");
contents.AppendLine(" if (vtableIndex > 0 && vtableIndex < entriesCount)");
contents.AppendLine(" {");
contents.AppendLine($" scriptVTableBase[{scriptVTableIndex} + 2] = vtable[vtableIndex];");
for (var i = 0; i < CppScriptObjectVirtualWrapperMethodsPostfixes.Count; i++)
{
contents.AppendLine(i == 0 ? " if (wrapperIndex == 0)" : $" else if (wrapperIndex == {i})");
contents.AppendLine(" {");
contents.AppendLine($" auto thunkPtr = &{classInfo.NativeName}Internal::{functionInfo.UniqueName}{CppScriptObjectVirtualWrapperMethodsPostfixes[i]};");
contents.AppendLine(" vtable[vtableIndex] = *(void**)&thunkPtr;");
contents.AppendLine(" }");
}
contents.AppendLine(" }");
contents.AppendLine(" else");
contents.AppendLine(" {");
contents.AppendLine($" LOG(Error, \"Failed to find the vtable entry for method {{0}} in class {{1}}\", TEXT(\"{functionInfo.Name}\"), TEXT(\"{classInfo.Name}\"));");
contents.AppendLine(" }");
contents.AppendLine(" }");
scriptVTableIndex++;
}
contents.AppendLine(" }");
contents.AppendLine("");
setupScriptVTable = $"&{classInfo.NativeName}Internal::SetupScriptVTable, &{classInfo.NativeName}Internal::SetupScriptObjectVTable";
if (!classInfo.IsScriptingObject)
throw new Exception($"Class {classInfo.Name} cannot implement interface {interfaceInfo.Name} because it requires ScriptingObject as a base class.");
GenerateCppWrapperFunction(buildData, contents, classInfo, functionInfo);
}
}
}
GenerateCppClassInternals?.Invoke(buildData, classInfo, contents);
// Virtual methods overrides
var setupScriptVTable = GenerateCppScriptVTable(buildData, contents, classInfo);
// Runtime initialization (internal methods binding)
contents.AppendLine(" static void InitRuntime()");
contents.AppendLine(" {");
@@ -1558,6 +1618,18 @@ namespace Flax.Build.Bindings
{
contents.AppendLine($" ADD_INTERNAL_CALL(\"{classTypeNameManagedInternalCall}::Internal_{functionInfo.UniqueName}\", &{functionInfo.UniqueName});");
}
if (hasInterface)
{
foreach (var interfaceInfo in classInfo.Interfaces)
{
if (interfaceInfo.Access != AccessLevel.Public)
continue;
foreach (var functionInfo in interfaceInfo.Functions)
{
contents.AppendLine($" ADD_INTERNAL_CALL(\"{classTypeNameManagedInternalCall}::Internal_{functionInfo.UniqueName}\", &{functionInfo.UniqueName});");
}
}
}
}
GenerateCppClassInitRuntime?.Invoke(buildData, classInfo, contents);
@@ -1811,29 +1883,102 @@ namespace Flax.Build.Bindings
{
var interfaceTypeNameNative = interfaceInfo.FullNameNative;
var interfaceTypeNameManaged = interfaceInfo.FullNameManaged;
var interfaceTypeNameManagedInternalCall = interfaceTypeNameManaged.Replace('+', '/');
var interfaceTypeNameInternal = interfaceInfo.NativeName;
if (interfaceInfo.Parent != null && !(interfaceInfo.Parent is FileInfo))
interfaceTypeNameInternal = interfaceInfo.Parent.FullNameNative + '_' + interfaceTypeNameInternal;
// Wrapper interface implement to invoke scripting if inherited in C# or VS
contents.AppendLine();
contents.AppendFormat("class {0}Wrapper : public ", interfaceTypeNameInternal).Append(interfaceTypeNameNative).AppendLine();
contents.Append('{').AppendLine();
contents.AppendLine("public:");
contents.AppendLine(" ScriptingObject* Object;");
foreach (var functionInfo in interfaceInfo.Functions)
{
if (!functionInfo.IsVirtual)
continue;
contents.AppendFormat(" {0} {1}(", functionInfo.ReturnType, functionInfo.Name);
var separator = false;
for (var i = 0; i < functionInfo.Parameters.Count; i++)
{
var parameterInfo = functionInfo.Parameters[i];
if (separator)
contents.Append(", ");
separator = true;
contents.Append(parameterInfo.Type).Append(' ').Append(parameterInfo.Name);
}
contents.Append(") override").AppendLine();
contents.AppendLine(" {");
// TODO: try to use ScriptVTable for interfaces implementation in scripting to call proper function instead of manually check at runtime
if (functionInfo.Parameters.Count != 0)
{
contents.AppendLine($" Variant parameters[{functionInfo.Parameters.Count}];");
for (var i = 0; i < functionInfo.Parameters.Count; i++)
{
var parameterInfo = functionInfo.Parameters[i];
contents.AppendLine($" parameters[{i}] = {GenerateCppWrapperNativeToVariant(buildData, parameterInfo.Type, interfaceInfo, parameterInfo.Name)};");
}
}
else
{
contents.AppendLine(" Variant* parameters = nullptr;");
}
contents.AppendLine(" auto typeHandle = Object->GetTypeHandle();");
contents.AppendLine(" while (typeHandle)");
contents.AppendLine(" {");
contents.AppendLine($" auto method = typeHandle.Module->FindMethod(typeHandle, \"{functionInfo.Name}\", {functionInfo.Parameters.Count});");
contents.AppendLine(" if (method)");
contents.AppendLine(" {");
contents.AppendLine(" Variant __result;");
contents.AppendLine($" typeHandle.Module->InvokeMethod(method, Object, Span<Variant>(parameters, {functionInfo.Parameters.Count}), __result);");
if (functionInfo.ReturnType.IsVoid)
contents.AppendLine(" return;");
else
contents.AppendLine($" return {GenerateCppWrapperVariantToNative(buildData, functionInfo.ReturnType, interfaceInfo, "__result")};");
contents.AppendLine(" }");
contents.AppendLine(" typeHandle = typeHandle.GetType().GetBaseType();");
contents.AppendLine(" }");
GenerateCppReturn(buildData, contents, " ", functionInfo.ReturnType);
contents.AppendLine(" }");
}
if (interfaceInfo.Name == "ISerializable")
{
// TODO: how to handle other interfaces that have some abstract native methods? maybe NativeOnly tag on interface? do it right and remove this hack
contents.AppendLine(" void Serialize(SerializeStream& stream, const void* otherObj) override {} void Deserialize(DeserializeStream& stream, ISerializeModifier* modifier) override {}");
}
contents.Append('}').Append(';').AppendLine();
contents.AppendLine();
contents.AppendFormat("class {0}Internal", interfaceTypeNameInternal).AppendLine();
contents.Append('{').AppendLine();
contents.AppendLine("public:");
GenerateCppClassInternals?.Invoke(buildData, interfaceInfo, contents);
// Virtual methods overrides
var setupScriptVTable = GenerateCppScriptVTable(buildData, contents, interfaceInfo);
// Runtime initialization (internal methods binding)
contents.AppendLine(" static void InitRuntime()");
contents.AppendLine(" {");
GenerateCppClassInitRuntime?.Invoke(buildData, interfaceInfo, contents);
contents.AppendLine(" }").AppendLine();
// Interface implementation wrapper accessor for scripting types
contents.AppendLine(" static void* GetInterfaceWrapper(ScriptingObject* obj)");
contents.AppendLine(" {");
contents.AppendLine($" auto wrapper = New<{interfaceTypeNameInternal}Wrapper>();");
contents.AppendLine(" wrapper->Object = obj;");
contents.AppendLine(" return wrapper;");
contents.AppendLine(" }");
contents.Append('}').Append(';').AppendLine();
contents.AppendLine();
// Type initializer
contents.Append($"ScriptingTypeInitializer {interfaceTypeNameNative}::TypeInitializer((BinaryModule*)GetBinaryModule{moduleInfo.Name}(), ");
contents.Append($"StringAnsiView(\"{interfaceTypeNameManaged}\", {interfaceTypeNameManaged.Length}), ");
contents.Append($"&{interfaceTypeNameInternal}Internal::InitRuntime");
contents.Append(");");
contents.Append($"StringAnsiView(\"{interfaceTypeNameManaged}\", {interfaceTypeNameManaged.Length}), &{interfaceTypeNameInternal}Internal::InitRuntime,");
contents.Append(setupScriptVTable).Append($", &{interfaceTypeNameInternal}Internal::GetInterfaceWrapper").Append(");");
contents.AppendLine();
// Nested types

View File

@@ -243,6 +243,8 @@ namespace Flax.Build.Bindings
classInfo.Functions.Add(functionInfo);
else if (context.ScopeInfo is StructureInfo structureInfo)
structureInfo.Functions.Add(functionInfo);
else if (context.ScopeInfo is InterfaceInfo interfaceInfo)
interfaceInfo.Functions.Add(functionInfo);
else
throw new Exception($"Not supported free-function {functionInfo.Name} at line {tokenizer.CurrentLine}. Place it in the class to use API bindings for it.");
}

View File

@@ -10,7 +10,7 @@ namespace Flax.Build.Bindings
/// <summary>
/// The native class information for bindings generator.
/// </summary>
public class ClassInfo : ClassStructInfo
public class ClassInfo : VirtualClassInfo
{
private static readonly HashSet<string> InBuildScriptingObjectTypes = new HashSet<string>
{
@@ -31,13 +31,10 @@ namespace Flax.Build.Bindings
public bool IsAutoSerialization;
public bool NoSpawn;
public bool NoConstructor;
public List<FunctionInfo> Functions = new List<FunctionInfo>();
public List<PropertyInfo> Properties = new List<PropertyInfo>();
public List<FieldInfo> Fields = new List<FieldInfo>();
public List<EventInfo> Events = new List<EventInfo>();
internal HashSet<string> UniqueFunctionNames;
private bool _isScriptingObject;
private int _scriptVTableSize = -1;
private int _scriptVTableOffset;
@@ -132,21 +129,6 @@ namespace Flax.Build.Bindings
if (propertyInfo.Setter != null)
ProcessAndValidate(propertyInfo.Setter);
}
foreach (var functionInfo in Functions)
ProcessAndValidate(functionInfo);
}
private void ProcessAndValidate(FunctionInfo functionInfo)
{
// Ensure that methods have unique names for bindings
if (UniqueFunctionNames == null)
UniqueFunctionNames = new HashSet<string>();
int idx = 1;
functionInfo.UniqueName = functionInfo.Name;
while (UniqueFunctionNames.Contains(functionInfo.UniqueName))
functionInfo.UniqueName = functionInfo.Name + idx++;
UniqueFunctionNames.Add(functionInfo.UniqueName);
}
public override void Write(BinaryWriter writer)
@@ -158,7 +140,6 @@ namespace Flax.Build.Bindings
writer.Write(IsAutoSerialization);
writer.Write(NoSpawn);
writer.Write(NoConstructor);
BindingsGenerator.Write(writer, Functions);
BindingsGenerator.Write(writer, Properties);
BindingsGenerator.Write(writer, Fields);
BindingsGenerator.Write(writer, Events);
@@ -175,7 +156,6 @@ namespace Flax.Build.Bindings
IsAutoSerialization = reader.ReadBoolean();
NoSpawn = reader.ReadBoolean();
NoConstructor = reader.ReadBoolean();
Functions = BindingsGenerator.Read(reader, Functions);
Properties = BindingsGenerator.Read(reader, Properties);
Fields = BindingsGenerator.Read(reader, Fields);
Events = BindingsGenerator.Read(reader, Events);
@@ -183,25 +163,51 @@ namespace Flax.Build.Bindings
base.Read(reader);
}
public int GetScriptVTableSize(Builder.BuildData buildData, out int offset)
public override int GetScriptVTableSize(out int offset)
{
if (_scriptVTableSize == -1)
{
if (BaseType is ClassInfo baseApiTypeInfo)
{
_scriptVTableOffset = baseApiTypeInfo.GetScriptVTableSize(buildData, out _);
_scriptVTableOffset = baseApiTypeInfo.GetScriptVTableSize(out _);
}
if (Interfaces != null)
{
foreach (var interfaceInfo in Interfaces)
{
if (interfaceInfo.Access != AccessLevel.Public)
continue;
_scriptVTableOffset += interfaceInfo.GetScriptVTableSize(out _);
}
}
_scriptVTableSize = _scriptVTableOffset + Functions.Count(x => x.IsVirtual);
if (IsSealed)
{
// Skip vtables for sealed classes
_scriptVTableSize = _scriptVTableOffset = 0;
}
}
offset = _scriptVTableOffset;
return _scriptVTableSize;
}
public override void AddChild(ApiTypeInfo apiTypeInfo)
public override int GetScriptVTableOffset(VirtualClassInfo classInfo)
{
apiTypeInfo.Namespace = null;
base.AddChild(apiTypeInfo);
if (classInfo == BaseType)
return 0;
if (Interfaces != null)
{
var offset = BaseType is ClassInfo baseApiTypeInfo ? baseApiTypeInfo.GetScriptVTableSize(out _) : 0;
foreach (var interfaceInfo in Interfaces)
{
if (interfaceInfo.Access != AccessLevel.Public)
continue;
if (interfaceInfo == classInfo)
return offset;
offset += interfaceInfo.GetScriptVTableSize(out _);
}
}
throw new Exception($"Cannot get Script VTable offset for {classInfo} that is not part of {this}");
}
public override string ToString()

View File

@@ -66,4 +66,59 @@ namespace Flax.Build.Bindings
base.Read(reader);
}
}
/// <summary>
/// The native class or interface information for bindings generator that contains virtual functions.
/// </summary>
public abstract class VirtualClassInfo : ClassStructInfo
{
public List<FunctionInfo> Functions = new List<FunctionInfo>();
internal HashSet<string> UniqueFunctionNames;
public override void Init(Builder.BuildData buildData)
{
base.Init(buildData);
foreach (var functionInfo in Functions)
ProcessAndValidate(functionInfo);
}
protected void ProcessAndValidate(FunctionInfo functionInfo)
{
// Ensure that methods have unique names for bindings
if (UniqueFunctionNames == null)
UniqueFunctionNames = new HashSet<string>();
int idx = 1;
functionInfo.UniqueName = functionInfo.Name;
while (UniqueFunctionNames.Contains(functionInfo.UniqueName))
functionInfo.UniqueName = functionInfo.Name + idx++;
UniqueFunctionNames.Add(functionInfo.UniqueName);
}
public abstract int GetScriptVTableSize(out int offset);
public abstract int GetScriptVTableOffset(VirtualClassInfo classInfo);
public override void Write(BinaryWriter writer)
{
BindingsGenerator.Write(writer, Functions);
base.Write(writer);
}
public override void Read(BinaryReader reader)
{
Functions = BindingsGenerator.Read(reader, Functions);
base.Read(reader);
}
public override void AddChild(ApiTypeInfo apiTypeInfo)
{
apiTypeInfo.Namespace = null;
base.AddChild(apiTypeInfo);
}
}
}

View File

@@ -1,19 +1,36 @@
// Copyright (c) 2012-2021 Wojciech Figat. All rights reserved.
using System;
using System.Linq;
namespace Flax.Build.Bindings
{
/// <summary>
/// The native class/structure interface information for bindings generator.
/// </summary>
public class InterfaceInfo : ClassStructInfo
public class InterfaceInfo : VirtualClassInfo
{
public override int GetScriptVTableSize(out int offset)
{
offset = 0;
return Functions.Count(x => x.IsVirtual);
}
public override int GetScriptVTableOffset(VirtualClassInfo classInfo)
{
return 0;
}
public override bool IsInterface => true;
public override void AddChild(ApiTypeInfo apiTypeInfo)
public override void Init(Builder.BuildData buildData)
{
apiTypeInfo.Namespace = null;
base.Init(buildData);
base.AddChild(apiTypeInfo);
if (BaseType != null)
throw new Exception(string.Format("Interface {0} cannot inherit from {1}.", FullNameNative, BaseType));
if (Interfaces != null && Interfaces.Count != 0)
throw new Exception(string.Format("Interface {0} cannot inherit from {1}.", FullNameNative, "interfaces"));
}
public override string ToString()

View File

@@ -20,7 +20,7 @@ namespace Flax.Build.Plugins
BindingsGenerator.CppScriptObjectVirtualWrapperMethodsPostfixes.Add("_VisualScriptWrapper");
}
private void OnGenerateCppScriptWrapperFunction(Builder.BuildData buildData, ClassInfo classInfo, FunctionInfo functionInfo, int scriptVTableSize, int scriptVTableIndex, StringBuilder contents)
private void OnGenerateCppScriptWrapperFunction(Builder.BuildData buildData, VirtualClassInfo classInfo, FunctionInfo functionInfo, int scriptVTableSize, int scriptVTableIndex, StringBuilder contents)
{
// Generate C++ wrapper function to invoke Visual Script instead of overridden native function (with support for base method callback)
@@ -43,13 +43,29 @@ namespace Flax.Build.Plugins
contents.Append(')');
contents.AppendLine();
contents.AppendLine(" {");
contents.AppendLine($" auto object = ({classInfo.NativeName}*)this;");
string scriptVTableOffset;
if (classInfo.IsInterface)
{
contents.AppendLine($" auto object = ScriptingObject::FromInterface(this, {classInfo.NativeName}::TypeInitializer);");
contents.AppendLine(" if (object == nullptr)");
contents.AppendLine(" {");
contents.AppendLine($" LOG(Error, \"Failed to cast interface {{0}} to scripting object\", TEXT(\"{classInfo.Name}\"));");
BindingsGenerator.GenerateCppReturn(buildData, contents, " ", functionInfo.ReturnType);
contents.AppendLine(" }");
contents.AppendLine($" const int32 scriptVTableOffset = {scriptVTableIndex} + object->GetType().GetInterface({classInfo.NativeName}::TypeInitializer)->ScriptVTableOffset;");
scriptVTableOffset = "scriptVTableOffset";
}
else
{
contents.AppendLine($" auto object = ({classInfo.NativeName}*)this;");
scriptVTableOffset = scriptVTableIndex.ToString();
}
contents.AppendLine(" static THREADLOCAL void* WrapperCallInstance = nullptr;");
contents.AppendLine(" if (WrapperCallInstance == object)");
contents.AppendLine(" {");
contents.AppendLine(" // Prevent stack overflow by calling base method");
contents.AppendLine(" const auto scriptVTableBase = object->GetType().Script.ScriptVTableBase;");
contents.Append($" return (this->**({functionInfo.UniqueName}_Internal_Signature*)&scriptVTableBase[{scriptVTableIndex} + 2])(");
contents.Append($" return (this->**({functionInfo.UniqueName}_Internal_Signature*)&scriptVTableBase[{scriptVTableOffset} + 2])(");
separator = false;
for (var i = 0; i < functionInfo.Parameters.Count; i++)
{
@@ -62,7 +78,7 @@ namespace Flax.Build.Plugins
contents.AppendLine(");");
contents.AppendLine(" }");
contents.AppendLine(" auto scriptVTable = (VisualScript::Method**)object->GetType().Script.ScriptVTable;");
contents.AppendLine($" ASSERT(scriptVTable && scriptVTable[{scriptVTableIndex}]);");
contents.AppendLine($" ASSERT(scriptVTable && scriptVTable[{scriptVTableOffset}]);");
if (functionInfo.Parameters.Count != 0)
{
@@ -80,7 +96,7 @@ namespace Flax.Build.Plugins
contents.AppendLine(" auto prevWrapperCallInstance = WrapperCallInstance;");
contents.AppendLine(" WrapperCallInstance = object;");
contents.AppendLine($" auto __result = VisualScripting::Invoke(scriptVTable[{scriptVTableIndex}], object, Span<Variant>(parameters, {functionInfo.Parameters.Count}));");
contents.AppendLine($" auto __result = VisualScripting::Invoke(scriptVTable[{scriptVTableOffset}], object, Span<Variant>(parameters, {functionInfo.Parameters.Count}));");
contents.AppendLine(" WrapperCallInstance = prevWrapperCallInstance;");
if (!functionInfo.ReturnType.IsVoid)