From c3c0a4ef0d79cee99348a65bb8110532e184830c Mon Sep 17 00:00:00 2001 From: Wojtek Figat Date: Mon, 4 Oct 2021 12:22:28 +0200 Subject: [PATCH] Add support for interfaces in scripting API (cross language support C++/C#/VS) --- Source/Engine/Content/Assets/VisualScript.cpp | 47 +- Source/Engine/Scripting/BinaryModule.cpp | 475 +++++++++++++----- Source/Engine/Scripting/BinaryModule.h | 4 +- Source/Engine/Scripting/ManagedCLR/MClass.h | 12 - Source/Engine/Scripting/Object.cs | 14 + Source/Engine/Scripting/ScriptingObject.cpp | 81 +++ Source/Engine/Scripting/ScriptingObject.h | 2 + Source/Engine/Scripting/ScriptingType.h | 28 +- .../Bindings/BindingsGenerator.CSharp.cs | 194 ++++++- .../Bindings/BindingsGenerator.Cache.cs | 2 +- .../Bindings/BindingsGenerator.Cpp.cs | 367 ++++++++++---- .../Flax.Build/Bindings/BindingsGenerator.cs | 2 + Source/Tools/Flax.Build/Bindings/ClassInfo.cs | 60 ++- .../Flax.Build/Bindings/ClassStructInfo.cs | 55 ++ .../Flax.Build/Bindings/InterfaceInfo.cs | 25 +- .../Build/Plugins/VisualScriptingPlugin.cs | 26 +- 16 files changed, 1039 insertions(+), 355 deletions(-) diff --git a/Source/Engine/Content/Assets/VisualScript.cpp b/Source/Engine/Content/Assets/VisualScript.cpp index fe3279856..8649a3dd1 100644 --- a/Source/Engine/Content/Assets/VisualScript.cpp +++ b/Source/Engine/Content/Assets/VisualScript.cpp @@ -1263,8 +1263,7 @@ Asset::LoadResult VisualScript::load() if (visualScriptType.Script.ScriptVTable) { // Override object vtable with hacked one that has Visual Script functions calls - ASSERT(visualScriptType.Script.VTable); - *(void**)object = visualScriptType.Script.VTable; + visualScriptType.HackObjectVTable(object, visualScriptType.BaseTypeHandle, 1); } } const int32 oldCount = _oldParamsLayout.Count(); @@ -1407,17 +1406,7 @@ void VisualScript::CacheScriptingType() type.ManagedClass = baseType.GetType().ManagedClass; // Create custom vtable for this class (build out of the wrapper C++ methods that call Visual Script graph) - // Call setup for all class starting from the first native type (first that uses virtual calls will allocate table of a proper size, further base types will just add own methods) - for (ScriptingTypeHandle e = nativeType; e;) - { - const ScriptingType& eType = e.GetType(); - if (eType.Script.SetupScriptVTable) - { - ASSERT(eType.ManagedClass); - eType.Script.SetupScriptVTable(eType.ManagedClass, type.Script.ScriptVTable, type.Script.ScriptVTableBase); - } - e = eType.GetBaseType(); - } + type.SetupScriptVTable(nativeType); MMethod** scriptVTable = (MMethod**)type.Script.ScriptVTable; while (scriptVTable && *scriptVTable) { @@ -1484,37 +1473,7 @@ ScriptingObject* VisualScriptingBinaryModule::VisualScriptObjectSpawn(const Scri } // Beware! Hacking vtables incoming! Undefined behaviors exploits! Low-level programming! - // What's happening here? - // We create a custom vtable for the Visual Script objects that use a native class object with virtual functions overrides. - // To make it easy to use in C++ we inject custom wrapper methods into C++ object vtable to execute Visual Script graph from them. - // Because virtual member functions calls are C++ ABI and impl-defined this is quite hard. But works. - if (visualScriptType.Script.ScriptVTable) - { - if (!visualScriptType.Script.VTable) - { - // Duplicate vtable - void** vtable = *(void***)object; - const int32 prefixSize = GetVTablePrefix(); - int32 entriesCount = 0; - while (vtable[entriesCount] && entriesCount < 200) - entriesCount++; - const int32 size = entriesCount * sizeof(void*); - visualScriptType.Script.VTable = (void**)((byte*)Platform::Allocate(prefixSize + size, 16) + prefixSize); - Platform::MemoryCopy((byte*)visualScriptType.Script.VTable - prefixSize, (byte*)vtable - prefixSize, prefixSize + size); - - // Override vtable entries by the class - for (ScriptingTypeHandle e = baseTypeHandle; e;) - { - const ScriptingType& eType = e.GetType(); - if (eType.Script.SetupScriptObjectVTable) - eType.Script.SetupScriptObjectVTable(visualScriptType.Script.ScriptVTable, visualScriptType.Script.ScriptVTableBase, visualScriptType.Script.VTable, entriesCount, 1); - e = eType.GetBaseType(); - } - } - - // Override object vtable with hacked one that has Visual Script functions calls - *(void**)object = visualScriptType.Script.VTable; - } + visualScriptType.HackObjectVTable(object, baseTypeHandle, 1); // Mark as custom scripting type object->Flags |= ObjectFlags::IsCustomScriptingType; diff --git a/Source/Engine/Scripting/BinaryModule.cpp b/Source/Engine/Scripting/BinaryModule.cpp index 61633ae9a..ef444afd9 100644 --- a/Source/Engine/Scripting/BinaryModule.cpp +++ b/Source/Engine/Scripting/BinaryModule.cpp @@ -81,6 +81,7 @@ ScriptingType::ScriptingType() { Script.Spawn = nullptr; Script.VTable = nullptr; + Script.InterfacesOffsets = nullptr; Script.ScriptVTable = nullptr; Script.ScriptVTableBase = nullptr; Script.SetupScriptVTable = nullptr; @@ -101,6 +102,7 @@ ScriptingType::ScriptingType(const StringAnsiView& fullname, BinaryModule* modul { Script.Spawn = spawn; Script.VTable = nullptr; + Script.InterfacesOffsets = nullptr; Script.ScriptVTable = nullptr; Script.ScriptVTableBase = nullptr; Script.SetupScriptVTable = setupScriptVTable; @@ -120,6 +122,7 @@ ScriptingType::ScriptingType(const StringAnsiView& fullname, BinaryModule* modul { Script.Spawn = spawn; Script.VTable = nullptr; + Script.InterfacesOffsets = nullptr; Script.ScriptVTable = nullptr; Script.ScriptVTableBase = nullptr; Script.SetupScriptVTable = setupScriptVTable; @@ -160,16 +163,19 @@ ScriptingType::ScriptingType(const StringAnsiView& fullname, BinaryModule* modul Struct.SetField = setField; } -ScriptingType::ScriptingType(const StringAnsiView& fullname, BinaryModule* module, InitRuntimeHandler initRuntime, ScriptingTypeInitializer* baseType, const InterfaceImplementation* interfaces) +ScriptingType::ScriptingType(const StringAnsiView& fullname, BinaryModule* module, InitRuntimeHandler initRuntime, SetupScriptVTableHandler setupScriptVTable, SetupScriptObjectVTableHandler setupScriptObjectVTable, GetInterfaceWrapper getInterfaceWrapper) : ManagedClass(nullptr) , Module(module) , InitRuntime(initRuntime) , Fullname(fullname) , Type(ScriptingTypes::Interface) - , BaseTypePtr(baseType) - , Interfaces(interfaces) + , BaseTypePtr(nullptr) + , Interfaces(nullptr) , Size(0) { + Interface.SetupScriptVTable = setupScriptVTable; + Interface.SetupScriptObjectVTable = setupScriptObjectVTable; + Interface.GetInterfaceWrapper = getInterfaceWrapper; } ScriptingType::ScriptingType(const ScriptingType& other) @@ -188,6 +194,7 @@ ScriptingType::ScriptingType(const ScriptingType& other) case ScriptingTypes::Script: Script.Spawn = other.Script.Spawn; Script.VTable = nullptr; + Script.InterfacesOffsets = nullptr; Script.ScriptVTable = nullptr; Script.ScriptVTableBase = nullptr; Script.SetupScriptVTable = other.Script.SetupScriptVTable; @@ -210,6 +217,9 @@ ScriptingType::ScriptingType(const ScriptingType& other) case ScriptingTypes::Enum: break; case ScriptingTypes::Interface: + Interface.SetupScriptVTable = other.Interface.SetupScriptVTable; + Interface.SetupScriptObjectVTable = other.Interface.SetupScriptObjectVTable; + Interface.GetInterfaceWrapper = other.Interface.GetInterfaceWrapper; break; default: ; } @@ -232,6 +242,8 @@ ScriptingType::ScriptingType(ScriptingType&& other) Script.Spawn = other.Script.Spawn; Script.VTable = other.Script.VTable; other.Script.VTable = nullptr; + Script.InterfacesOffsets = other.Script.InterfacesOffsets; + other.Script.InterfacesOffsets = nullptr; Script.ScriptVTable = other.Script.ScriptVTable; other.Script.ScriptVTable = nullptr; Script.ScriptVTableBase = other.Script.ScriptVTableBase; @@ -257,6 +269,9 @@ ScriptingType::ScriptingType(ScriptingType&& other) case ScriptingTypes::Enum: break; case ScriptingTypes::Interface: + Interface.SetupScriptVTable = other.Interface.SetupScriptVTable; + Interface.SetupScriptObjectVTable = other.Interface.SetupScriptObjectVTable; + Interface.GetInterfaceWrapper = other.Interface.GetInterfaceWrapper; break; default: ; } @@ -271,6 +286,7 @@ ScriptingType::~ScriptingType() Delete(Script.DefaultInstance); if (Script.VTable) Platform::Free((byte*)Script.VTable - GetVTablePrefix()); + Platform::Free(Script.InterfacesOffsets); Platform::Free(Script.ScriptVTable); Platform::Free(Script.ScriptVTableBase); break; @@ -332,6 +348,172 @@ const ScriptingType::InterfaceImplementation* ScriptingType::GetInterface(const return nullptr; } +void ScriptingType::SetupScriptVTable(ScriptingTypeHandle baseTypeHandle) +{ + // Call setup for all class starting from the first native type (first that uses virtual calls will allocate table of a proper size, further base types will just add own methods) + for (ScriptingTypeHandle e = baseTypeHandle; e;) + { + const ScriptingType& eType = e.GetType(); + + if (eType.Script.SetupScriptVTable) + { + ASSERT(eType.ManagedClass); + eType.Script.SetupScriptVTable(eType.ManagedClass, Script.ScriptVTable, Script.ScriptVTableBase); + } + + auto interfaces = eType.Interfaces; + if (interfaces && Script.ScriptVTable) + { + while (interfaces->InterfaceType) + { + auto& interfaceType = interfaces->InterfaceType->GetType(); + if (interfaceType.Interface.SetupScriptVTable) + { + ASSERT(eType.ManagedClass); + const auto scriptOffset = interfaces->ScriptVTableOffset; // Shift the script vtable for the interface implementation start + Script.ScriptVTable += scriptOffset; + Script.ScriptVTableBase += scriptOffset; + interfaceType.Interface.SetupScriptVTable(eType.ManagedClass, Script.ScriptVTable, Script.ScriptVTableBase); + Script.ScriptVTable -= scriptOffset; + Script.ScriptVTableBase -= scriptOffset; + } + interfaces++; + } + } + e = eType.GetBaseType(); + } +} + +void ScriptingType::SetupScriptObjectVTable(void* object, ScriptingTypeHandle baseTypeHandle, int32 wrapperIndex) +{ + // Analyze vtable size + void** vtable = *(void***)object; + const int32 prefixSize = GetVTablePrefix(); + int32 entriesCount = 0; + while (vtable[entriesCount] && entriesCount < 200) + entriesCount++; + + // Calculate total vtable size by adding all implemented interfaces that use virtual methods + const int32 size = entriesCount * sizeof(void*); + int32 totalSize = prefixSize + size; + int32 interfacesCount = 0; + for (ScriptingTypeHandle e = baseTypeHandle; e;) + { + const ScriptingType& eType = e.GetType(); + auto interfaces = eType.Interfaces; + if (interfaces) + { + while (interfaces->InterfaceType) + { + auto& interfaceType = interfaces->InterfaceType->GetType(); + if (interfaceType.Interface.SetupScriptObjectVTable) + { + void** vtableInterface = *(void***)((byte*)object + interfaces->VTableOffset); + int32 interfaceCount = 0; + while (vtableInterface[interfaceCount] && interfaceCount < 200) + interfaceCount++; + totalSize += prefixSize + interfaceCount * sizeof(void*); + interfacesCount++; + } + interfaces++; + } + } + e = eType.GetBaseType(); + } + + // Duplicate vtable + Script.VTable = (void**)((byte*)Platform::Allocate(totalSize, 16) + prefixSize); + Platform::MemoryCopy((byte*)Script.VTable - prefixSize, (byte*)vtable - prefixSize, prefixSize + size); + + // Override vtable entries + if (interfacesCount) + Script.InterfacesOffsets = (uint16*)Platform::Allocate(interfacesCount * sizeof(uint16*), 16); + int32 interfaceOffset = size; + interfacesCount = 0; + for (ScriptingTypeHandle e = baseTypeHandle; e;) + { + const ScriptingType& eType = e.GetType(); + + if (eType.Script.SetupScriptObjectVTable) + { + // Override vtable entries for this class + eType.Script.SetupScriptObjectVTable(Script.ScriptVTable, Script.ScriptVTableBase, Script.VTable, entriesCount, wrapperIndex); + } + + auto interfaces = eType.Interfaces; + if (interfaces) + { + while (interfaces->InterfaceType) + { + auto& interfaceType = interfaces->InterfaceType->GetType(); + if (interfaceType.Interface.SetupScriptObjectVTable) + { + // Analyze interface vtable size + void** vtableInterface = *(void***)((byte*)object + interfaces->VTableOffset); + int32 interfaceCount = 0; + while (vtableInterface[interfaceCount] && interfaceCount < 200) + interfaceCount++; + const int32 interfaceSize = interfaceCount * sizeof(void*); + + // Duplicate interface vtable + Platform::MemoryCopy((byte*)Script.VTable + interfaceOffset, (byte*)vtableInterface - prefixSize, prefixSize + interfaceSize); + + // Override interface vtable entries + const auto scriptOffset = interfaces->ScriptVTableOffset; + const auto nativeOffset = interfaceOffset + prefixSize; + void** interfaceVTable = (void**)((byte*)Script.VTable + nativeOffset); + interfaceType.Interface.SetupScriptObjectVTable(Script.ScriptVTable + scriptOffset, Script.ScriptVTableBase + scriptOffset, interfaceVTable, interfaceCount, wrapperIndex); + + Script.InterfacesOffsets[interfacesCount++] = (uint16)nativeOffset; + interfaceOffset += prefixSize + interfaceSize; + } + interfaces++; + } + } + e = eType.GetBaseType(); + } +} + +void ScriptingType::HackObjectVTable(void* object, ScriptingTypeHandle baseTypeHandle, int32 wrapperIndex) +{ + if (!Script.ScriptVTable) + return; + if (!Script.VTable) + { + // Ensure to have valid Script VTable hacked + SetupScriptObjectVTable(object, baseTypeHandle, wrapperIndex); + } + + // Override object vtable with hacked one that has calls to overriden scripting functions + *(void**)object = Script.VTable; + + if (Script.InterfacesOffsets) + { + // Override vtables for interfaces + int32 interfacesCount = 0; + for (ScriptingTypeHandle e = baseTypeHandle; e;) + { + const ScriptingType& eType = e.GetType(); + auto interfaces = eType.Interfaces; + if (interfaces) + { + while (interfaces->InterfaceType) + { + auto& interfaceType = interfaces->InterfaceType->GetType(); + if (interfaceType.Interface.SetupScriptObjectVTable) + { + void** interfaceVTable = (void**)((byte*)Script.VTable + Script.InterfacesOffsets[interfacesCount++]); + *(void**)((byte*)object + interfaces->VTableOffset) = interfaceVTable; + interfacesCount++; + } + interfaces++; + } + } + e = eType.GetBaseType(); + } + } +} + String ScriptingType::ToString() const { return String(Fullname.Get(), Fullname.Length()); @@ -382,12 +564,12 @@ ScriptingTypeInitializer::ScriptingTypeInitializer(BinaryModule* module, const S module->TypeNameToTypeIndex[fullname] = TypeIndex; } -ScriptingTypeInitializer::ScriptingTypeInitializer(BinaryModule* module, const StringAnsiView& fullname, ScriptingType::InitRuntimeHandler initRuntime, ScriptingTypeInitializer* baseType, const ScriptingType::InterfaceImplementation* interfaces) +ScriptingTypeInitializer::ScriptingTypeInitializer(BinaryModule* module, const StringAnsiView& fullname, ScriptingType::InitRuntimeHandler initRuntime, ScriptingType::SetupScriptVTableHandler setupScriptVTable, ScriptingType::SetupScriptObjectVTableHandler setupScriptObjectVTable, ScriptingType::GetInterfaceWrapper getInterfaceWrapper) : ScriptingTypeHandle(module, module->Types.Count()) { // Interface module->Types.AddUninitialized(); - new(module->Types.Get() + TypeIndex)ScriptingType(fullname, module, initRuntime, baseType, interfaces); + new(module->Types.Get() + TypeIndex)ScriptingType(fullname, module, initRuntime, setupScriptVTable, setupScriptObjectVTable, getInterfaceWrapper); #if BUILD_DEBUG if (module->TypeNameToTypeIndex.ContainsKey(fullname)) { @@ -508,39 +690,7 @@ ScriptingObject* ManagedBinaryModule::ManagedObjectSpawn(const ScriptingObjectSp } // Beware! Hacking vtables incoming! Undefined behaviors exploits! Low-level programming! - // What's happening here? - // We create a custom vtable for the C# objects that use a native class object with virtual functions overrides. - // To make it easy to use in C++ we inject custom wrapper methods into C++ object vtable to call C# code from them. - // Because virtual member functions calls are C++ ABI and impl-defined this is quite hard. But works. - if (managedType.Script.ScriptVTable) - { - if (!managedType.Script.VTable) - { - // Duplicate vtable - void** vtable = *(void***)object; - const int32 prefixSize = GetVTablePrefix(); - int32 entriesCount = 0; - while (vtable[entriesCount] && entriesCount < 200) - entriesCount++; - const int32 size = entriesCount * sizeof(void*); - managedType.Script.VTable = (void**)((byte*)Platform::Allocate(prefixSize + size, 16) + prefixSize); - Platform::MemoryCopy((byte*)managedType.Script.VTable - prefixSize, (byte*)vtable - prefixSize, prefixSize + size); - - // Override vtable entries by the class - for (ScriptingTypeHandle e = nativeTypeHandle; e;) - { - const ScriptingType& eType = e.GetType(); - if (eType.Script.SetupScriptObjectVTable) - { - eType.Script.SetupScriptObjectVTable(managedType.Script.ScriptVTable, managedType.Script.ScriptVTableBase, managedType.Script.VTable, entriesCount, 0); - } - e = eType.GetBaseType(); - } - } - - // Override object vtable with hacked one that has C# functions calls - *(void**)object = managedType.Script.VTable; - } + managedType.HackObjectVTable(object, nativeTypeHandle, 0); // Mark as managed type object->Flags |= ObjectFlags::IsManagedType; @@ -618,6 +768,20 @@ ManagedBinaryModule* ManagedBinaryModule::FindModule(MonoClass* klass) return module; } +ScriptingTypeHandle ManagedBinaryModule::FindType(MonoClass* klass) +{ + auto typeModule = FindModule(klass); + if (typeModule) + { + int32 typeIndex; + if (typeModule->ClassToTypeIndex.TryGet(klass, typeIndex)) + { + return ScriptingTypeHandle(typeModule, typeIndex); + } + } + return ScriptingTypeHandle(); +} + void ManagedBinaryModule::OnLoading(MAssembly* assembly) { PROFILE_CPU(); @@ -672,13 +836,9 @@ void ManagedBinaryModule::OnLoaded(MAssembly* assembly) for (auto i = classes.Begin(); i.IsNotEnd(); ++i) { MClass* mclass = i->Value; - const MString& typeName = mclass->GetFullName(); // Check if C# class inherits from C++ object class it has no C++ representation - if ( - TypeNameToTypeIndex.Find(typeName) != TypeNameToTypeIndex.End() || - mclass->IsStatic() || - mclass->IsAbstract() || + if (mclass->IsStatic() || mclass->IsInterface() || !mclass->IsSubClassOf(scriptingObjectType) ) @@ -686,96 +846,151 @@ void ManagedBinaryModule::OnLoaded(MAssembly* assembly) continue; } - // Find first native base C++ class of this C# class - MClass* baseClass = mclass->GetBaseClass(); - ScriptingTypeHandle nativeType; - while (baseClass) + InitType(mclass); + } + } +} + +void ManagedBinaryModule::InitType(MClass* mclass) +{ + // Skip if already initialized + const MString& typeName = mclass->GetFullName(); + if (TypeNameToTypeIndex.ContainsKey(typeName)) + return; + + // Find first native base C++ class of this C# class + MClass* baseClass = nullptr; + MonoClass* baseKlass = mono_class_get_parent(mclass->GetNative()); + MonoImage* baseKlassImage = mono_class_get_image(baseKlass); + ScriptingTypeHandle baseType; + auto& modules = GetModules(); + for (int32 i = 0; i < modules.Count(); i++) + { + auto e = dynamic_cast(modules[i]); + if (e && e->Assembly->GetMonoImage() == baseKlassImage) + { + baseType.Module = e; + baseClass = e->Assembly->GetClass(baseKlass); + break; + } + } + if (!baseClass) + { + LOG(Error, "Missing base class for managed class {0} from assembly {1}.", String(typeName), Assembly->ToString()); + return; + } + if (baseType.Module == this) + InitType(baseClass); // Ensure base is initialized before + baseType.Module->TypeNameToTypeIndex.TryGet(baseClass->GetFullName(), *(int32*)&baseType.TypeIndex); + if (!baseType) + { + LOG(Error, "Missing base class for managed class {0} from assembly {1}.", String(typeName), Assembly->ToString()); + return; + } + ScriptingTypeHandle nativeType = baseType; + while (true) + { + auto& type = nativeType.GetType(); + if (type.Script.Spawn != &ManagedObjectSpawn) + break; + nativeType = type.GetBaseType(); + if (!nativeType) + { + LOG(Error, "Missing base class for managed class {0} from assembly {1}.", String(typeName), Assembly->ToString()); + return; + } + } + + // Scripting Type has Fullname span pointing to the string in memory (usually static data) so store the name in assembly + char* typeNameData = (char*)Allocator::Allocate(typeName.Length() + 1); + Platform::MemoryCopy(typeNameData, typeName.Get(), typeName.Length()); + typeNameData[typeName.Length()] = 0; + _managedMemoryBlocks.Add(typeNameData); + + // Initialize scripting interfaces implemented in C# + MonoClass* interfaceKlass; + void* interfaceIt = nullptr; + int32 interfacesCount = 0; + MonoClass* klass = mclass->GetNative(); + while (interfaceKlass = mono_class_get_interfaces(klass, &interfaceIt)) + { + const ScriptingTypeHandle interfaceType = FindType(interfaceKlass); + if (interfaceType) + interfacesCount++; + } + ScriptingType::InterfaceImplementation* interfaces = nullptr; + if (interfacesCount != 0) + { + interfaces = (ScriptingType::InterfaceImplementation*)Allocator::Allocate((interfacesCount + 1) * sizeof(ScriptingType::InterfaceImplementation)); + interfacesCount = 0; + interfaceIt = nullptr; + while (interfaceKlass = mono_class_get_interfaces(klass, &interfaceIt)) + { + const ScriptingTypeHandle interfaceTypeHandle = FindType(interfaceKlass); + if (!interfaceTypeHandle) + continue; + auto& interface = interfaces[interfacesCount++]; + auto ptr = (ScriptingTypeHandle*)Allocator::Allocate(sizeof(ScriptingTypeHandle)); + *ptr = interfaceTypeHandle; + _managedMemoryBlocks.Add(ptr); + interface.InterfaceType = ptr; + interface.VTableOffset = 0; + interface.ScriptVTableOffset = 0; + interface.IsNative = false; + } + Platform::MemoryClear(interfaces + interfacesCount, sizeof(ScriptingType::InterfaceImplementation)); + _managedMemoryBlocks.Add(interfaces); + } + + // Create scripting type descriptor for managed-only type based on the native base class + const int32 typeIndex = Types.Count(); + Types.AddUninitialized(); + new(Types.Get() + Types.Count() - 1)ScriptingType(typeName, this, baseType.GetType().Size, ScriptingType::DefaultInitRuntime, ManagedObjectSpawn, baseType, nullptr, nullptr, interfaces); + TypeNameToTypeIndex[typeName] = typeIndex; + auto& type = Types[typeIndex]; + type.ManagedClass = mclass; + + // Register Mono class + ASSERT(!ClassToTypeIndex.ContainsKey(klass)); + ClassToTypeIndex[klass] = typeIndex; + + // Create managed vtable for this class (build out of the wrapper C++ methods that call C# methods) + type.SetupScriptVTable(nativeType); + MMethod** scriptVTable = (MMethod**)type.Script.ScriptVTable; + while (scriptVTable && *scriptVTable) + { + const MMethod* referenceMethod = *scriptVTable; + + // Find that method overriden in C# class (the current or one of the base classes in C#) + MMethod* method = ::FindMethod(mclass, referenceMethod); + if (method == nullptr) + { + // Check base classes (skip native class) + baseClass = mclass->GetBaseClass(); + MClass* nativeBaseClass = nativeType.GetType().ManagedClass; + while (baseClass && baseClass != nativeBaseClass && method == nullptr) { - int32 typeIndex; - BinaryModule* baseClassModule = GetModule(baseClass->GetAssembly()); - ASSERT(baseClassModule); - if (baseClassModule->TypeNameToTypeIndex.TryGet(baseClass->GetFullName(), typeIndex)) + method = ::FindMethod(baseClass, referenceMethod); + + // Special case if method was found but the base class uses generic arguments + if (method && baseClass->IsGeneric()) { - nativeType = ScriptingTypeHandle(baseClassModule, typeIndex); - if (nativeType.GetType().Script.Spawn != &ManagedObjectSpawn) - break; + // TODO: encapsulate it into MClass to support inflated methods + auto parentClass = mono_class_get_parent(mclass->GetNative()); + auto parentMethod = mono_class_get_method_from_name(parentClass, referenceMethod->GetName().Get(), 0); + auto inflatedMethod = mono_class_inflate_generic_method(parentMethod, nullptr); + method = New(inflatedMethod, baseClass); } + baseClass = baseClass->GetBaseClass(); } - if (!nativeType) - { - LOG(Error, "Missing native base class for managed class {0} from assembly {1}.", String(typeName), assembly->ToString()); - continue; - } - - // Scripting Type has Fullname span pointing to the string in memory (usually static data) so store the name in assembly - char* typeNameData = (char*)Allocator::Allocate(typeName.Length() + 1); - Platform::MemoryCopy(typeNameData, typeName.Get(), typeName.Length()); - typeNameData[typeName.Length()] = 0; - _managedTypesNames.Add(typeNameData); - - // Create scripting type descriptor for managed-only type based on the native base class - const int32 typeIndex = Types.Count(); - Types.AddUninitialized(); - new(Types.Get() + Types.Count() - 1)ScriptingType(typeName, this, nativeType.GetType().Size, ScriptingType::DefaultInitRuntime, ManagedObjectSpawn, nativeType); - TypeNameToTypeIndex[typeName] = typeIndex; - auto& type = Types[typeIndex]; - type.ManagedClass = mclass; - - // Register Mono class - MonoClass* klass = mclass->GetNative(); - ASSERT(!ClassToTypeIndex.ContainsKey(klass)); - ClassToTypeIndex[klass] = typeIndex; - - // Create managed vtable for this class (build out of the wrapper C++ methods that call C# methods) - // Call setup for all class starting from the first native type (first that uses virtual calls will allocate table of a proper size, further base types will just add own methods) - for (ScriptingTypeHandle e = nativeType; e;) - { - const ScriptingType& eType = e.GetType(); - if (eType.Script.SetupScriptVTable) - { - ASSERT(eType.ManagedClass); - eType.Script.SetupScriptVTable(eType.ManagedClass, type.Script.ScriptVTable, type.Script.ScriptVTableBase); - } - e = eType.GetBaseType(); - } - MMethod** scriptVTable = (MMethod**)type.Script.ScriptVTable; - while (scriptVTable && *scriptVTable) - { - const MMethod* referenceMethod = *scriptVTable; - - // Find that method overriden in C# class (the current or one of the base classes in C#) - MMethod* method = ::FindMethod(mclass, referenceMethod); - if (method == nullptr) - { - // Check base classes (skip native class) - baseClass = mclass->GetBaseClass(); - MClass* nativeBaseClass = nativeType.GetType().ManagedClass; - while (baseClass && baseClass != nativeBaseClass && method == nullptr) - { - method = ::FindMethod(baseClass, referenceMethod); - - // Special case if method was found but the base class uses generic arguments - if (method && baseClass->IsGeneric()) - { - // TODO: encapsulate it into MClass to support inflated methods - auto parentClass = mono_class_get_parent(mclass->GetNative()); - auto parentMethod = mono_class_get_method_from_name(parentClass, referenceMethod->GetName().Get(), 0); - auto inflatedMethod = mono_class_inflate_generic_method(parentMethod, nullptr); - method = New(inflatedMethod, baseClass); - } - - baseClass = baseClass->GetBaseClass(); - } - } - - // Set the method to call (null entry marks unused entries that won't use C# wrapper calls) - *scriptVTable = method; - - // Move to the next entry (table is null terminated) - scriptVTable++; - } } + + // Set the method to call (null entry marks unused entries that won't use C# wrapper calls) + *scriptVTable = method; + + // Move to the next entry (table is null terminated) + scriptVTable++; } } @@ -791,9 +1006,9 @@ void ManagedBinaryModule::OnUnloading(MAssembly* assembly) TypeNameToTypeIndex.Remove(typeName); } Types.Resize(_firstManagedTypeIndex); - for (int32 i = 0; i < _managedTypesNames.Count(); i++) - Allocator::Free(_managedTypesNames[i]); - _managedTypesNames.Clear(); + for (int32 i = 0; i < _managedMemoryBlocks.Count(); i++) + Allocator::Free(_managedMemoryBlocks[i]); + _managedMemoryBlocks.Clear(); // Clear managed types information for (ScriptingType& type : Types) diff --git a/Source/Engine/Scripting/BinaryModule.h b/Source/Engine/Scripting/BinaryModule.h index d76d46c5a..725400d38 100644 --- a/Source/Engine/Scripting/BinaryModule.h +++ b/Source/Engine/Scripting/BinaryModule.h @@ -261,7 +261,7 @@ public: private: int32 _firstManagedTypeIndex; - Array _managedTypesNames; + Array _managedMemoryBlocks; public: @@ -298,11 +298,13 @@ public: static ScriptingObject* ManagedObjectSpawn(const ScriptingObjectSpawnParams& params); static MMethod* FindMethod(MClass* mclass, const ScriptingTypeMethodSignature& signature); static ManagedBinaryModule* FindModule(MonoClass* klass); + static ScriptingTypeHandle FindType(MonoClass* klass); private: void OnLoading(MAssembly* assembly); void OnLoaded(MAssembly* assembly); + void InitType(MClass* mclass); void OnUnloading(MAssembly* assembly); public: diff --git a/Source/Engine/Scripting/ManagedCLR/MClass.h b/Source/Engine/Scripting/ManagedCLR/MClass.h index e4d0ba74c..255abd84b 100644 --- a/Source/Engine/Scripting/ManagedCLR/MClass.h +++ b/Source/Engine/Scripting/ManagedCLR/MClass.h @@ -58,7 +58,6 @@ public: /// /// Gets the parent assembly. /// - /// The assembly. const MAssembly* GetAssembly() const { return _assembly; @@ -67,7 +66,6 @@ public: /// /// Gets the full name of the class (namespace and typename). /// - /// The fullname. FORCE_INLINE const MString& GetFullName() const { return _fullname; @@ -78,7 +76,6 @@ public: /// /// Gets the Mono class handle. /// - /// The Mono class. MonoClass* GetNative() const; #endif @@ -86,7 +83,6 @@ public: /// /// Gets class visibility /// - /// Returns visibility struct. FORCE_INLINE MVisibility GetVisibility() const { return _visibility; @@ -95,7 +91,6 @@ public: /// /// Gets if class is static /// - /// Returns true if class is static, otherwise false. FORCE_INLINE bool IsStatic() const { return _isStatic != 0; @@ -104,7 +99,6 @@ public: /// /// Gets if class is abstract /// - /// Returns true if class is static, otherwise false. FORCE_INLINE bool IsAbstract() const { return _isAbstract != 0; @@ -113,7 +107,6 @@ public: /// /// Gets if class is sealed /// - /// Returns true if class is static, otherwise false. FORCE_INLINE bool IsSealed() const { return _isSealed != 0; @@ -122,7 +115,6 @@ public: /// /// Gets if class is interface /// - /// Returns true if class is static, otherwise false. FORCE_INLINE bool IsInterface() const { return _isInterface != 0; @@ -131,19 +123,16 @@ public: /// /// Gets if class is generic /// - /// Returns true if class is generic type, otherwise false. bool IsGeneric() const; /// /// Gets the class type. /// - /// The type. MType GetType() const; /// /// Returns the base class of this class. Null if this class has no base. /// - /// The base class. MClass* GetBaseClass() const; /// @@ -170,7 +159,6 @@ public: /// /// Returns the size of an instance of this class, in bytes. /// - /// The instance size (in bytes). uint32 GetInstanceSize() const; public: diff --git a/Source/Engine/Scripting/Object.cs b/Source/Engine/Scripting/Object.cs index a4e2424e9..1b2a03fc0 100644 --- a/Source/Engine/Scripting/Object.cs +++ b/Source/Engine/Scripting/Object.cs @@ -210,6 +210,17 @@ namespace FlaxEngine return GetUnmanagedPtr(reference.Get()); } + /// + /// Gets the pointer to the native interface implementation. Handles null object reference or invalid cast (returns zero). + /// + /// The object. + /// The interface type. + /// The native interface pointer. + public static IntPtr GetUnmanagedInterface(object obj, Type type) + { + return obj is Object o ? Internal_GetUnmanagedInterface(o.__unmanagedPtr, type) : IntPtr.Zero; + } + /// public override int GetHashCode() { @@ -245,6 +256,9 @@ namespace FlaxEngine [MethodImpl(MethodImplOptions.InternalCall)] internal static extern void Internal_ChangeID(IntPtr obj, ref Guid id); + [MethodImpl(MethodImplOptions.InternalCall)] + internal static extern IntPtr Internal_GetUnmanagedInterface(IntPtr obj, Type type); + #endregion } } diff --git a/Source/Engine/Scripting/ScriptingObject.cpp b/Source/Engine/Scripting/ScriptingObject.cpp index ca4f95a51..f2b07bcb9 100644 --- a/Source/Engine/Scripting/ScriptingObject.cpp +++ b/Source/Engine/Scripting/ScriptingObject.cpp @@ -9,6 +9,7 @@ #include "Engine/Utilities/StringConverter.h" #include "Engine/Content/Asset.h" #include "Engine/Content/Content.h" +#include "Engine/Profiler/ProfilerCPU.h" #include "ManagedCLR/MAssembly.h" #include "ManagedCLR/MClass.h" #include "ManagedCLR/MUtils.h" @@ -23,6 +24,9 @@ #define ScriptingObject_unmanagedPtr "__unmanagedPtr" #define ScriptingObject_id "__internalId" +// TODO: don't leak memory (use some kind of late manual GC for those wrapper objects) +Dictionary ScriptingObjectsInterfaceWrappers; + ScriptingObject::ScriptingObject(const SpawnParams& params) : _gcHandle(0) , _type(params.Type) @@ -71,6 +75,51 @@ MClass* ScriptingObject::GetClass() const return _type ? _type.GetType().ManagedClass : nullptr; } +ScriptingObject* ScriptingObject::FromInterface(void* interfaceObj, ScriptingTypeHandle& interfaceType) +{ + if (!interfaceObj || !interfaceType) + return nullptr; + PROFILE_CPU(); + + // Find the type which implements this interface and has the same vtable as interface object + // TODO: implement vtableInterface->type hashmap caching in Scripting service to optimize sequential interface casts + auto& modules = BinaryModule::GetModules(); + for (auto module : modules) + { + for (auto& type : module->Types) + { + if (type.Type != ScriptingTypes::Script) + continue; + auto interfaceImpl = type.GetInterface(interfaceType); + if (interfaceImpl && interfaceImpl->IsNative) + { + ScriptingObject* predictedObj = (ScriptingObject*)((byte*)interfaceObj - interfaceImpl->VTableOffset); + void* predictedVTable = *(void***)predictedObj; + void* vtable = type.Script.VTable; + if (!vtable && type.GetDefaultInstance()) + { + // Use vtable from default instance of this type + vtable = *(void***)type.GetDefaultInstance(); + } + if (vtable == predictedVTable) + { + ASSERT(predictedObj->GetType().GetInterface(interfaceType)); + return predictedObj; + } + } + } + } + + // Special case for interface wrapper object + for (auto& e : ScriptingObjectsInterfaceWrappers) + { + if (e.Value == interfaceObj) + return e.Key; + } + + return nullptr; +} + ScriptingObject* ScriptingObject::ToNative(MonoObject* obj) { ScriptingObject* ptr = nullptr; @@ -575,6 +624,37 @@ public: obj->ChangeID(*id); } + static void* GetUnmanagedInterface(ScriptingObject* obj, MonoReflectionType* type) + { + if (obj && type) + { + auto typeClass = MUtils::GetClass(type); + const ScriptingTypeHandle interfaceType = ManagedBinaryModule::FindType(typeClass); + if (interfaceType) + { + const ScriptingType& objectType = obj->GetType(); + const ScriptingType::InterfaceImplementation* interface = objectType.GetInterface(interfaceType); + if (interface && interface->IsNative) + { + // Native interface so just offset pointer to the interface vtable start + return (byte*)obj + interface->VTableOffset; + } + if (interface) + { + // Interface implemented in scripting (eg. C# class inherits C++ interface) + void* result; + if (!ScriptingObjectsInterfaceWrappers.TryGet(obj, result)) + { + result = interfaceType.GetType().Interface.GetInterfaceWrapper(obj); + ScriptingObjectsInterfaceWrappers.Add(obj, result); + } + return result; + } + } + } + return nullptr; + } + static void InitRuntime() { ADD_INTERNAL_CALL("FlaxEngine.Object::Internal_Create1", &Create1); @@ -586,6 +666,7 @@ public: ADD_INTERNAL_CALL("FlaxEngine.Object::Internal_FindObject", &FindObject); ADD_INTERNAL_CALL("FlaxEngine.Object::Internal_TryFindObject", &TryFindObject); ADD_INTERNAL_CALL("FlaxEngine.Object::Internal_ChangeID", &ChangeID); + ADD_INTERNAL_CALL("FlaxEngine.Object::Internal_GetUnmanagedInterface", &GetUnmanagedInterface); } static ScriptingObject* Spawn(const ScriptingObjectSpawnParams& params) diff --git a/Source/Engine/Scripting/ScriptingObject.h b/Source/Engine/Scripting/ScriptingObject.h index 01dd84232..c33a9dc93 100644 --- a/Source/Engine/Scripting/ScriptingObject.h +++ b/Source/Engine/Scripting/ScriptingObject.h @@ -108,6 +108,8 @@ public: public: + // Tries to cast native interface object to scripting object instance. Returns null if fails. + static ScriptingObject* FromInterface(void* interfaceObj, ScriptingTypeHandle& interfaceType); static ScriptingObject* ToNative(MonoObject* obj); static MonoObject* ToManaged(ScriptingObject* obj) diff --git a/Source/Engine/Scripting/ScriptingType.h b/Source/Engine/Scripting/ScriptingType.h index 855e7f234..1e18c3e7d 100644 --- a/Source/Engine/Scripting/ScriptingType.h +++ b/Source/Engine/Scripting/ScriptingType.h @@ -117,6 +117,7 @@ struct FLAXENGINE_API ScriptingType typedef void (*Unbox)(void* ptr, MonoObject* managed); typedef void (*GetField)(void* ptr, const String& name, Variant& value); typedef void (*SetField)(void* ptr, const String& name, const Variant& value); + typedef void* (*GetInterfaceWrapper)(ScriptingObject* obj); struct InterfaceImplementation { @@ -125,6 +126,12 @@ struct FLAXENGINE_API ScriptingType // The offset (in bytes) from the object pointer to the interface implementation. Used for casting object to the interface. int16 VTableOffset; + + // The offset (in entries) from the script vtable to the interface implementation. Used for initializing interface virtual script methods. + int16 ScriptVTableOffset; + + // True if interface implementation is native (inside C++ object), otherwise it's injected at scripting level (cannot call interface directly). + bool IsNative; }; /// @@ -186,6 +193,11 @@ struct FLAXENGINE_API ScriptingType /// void** VTable; + /// + /// List of offsets from native methods VTable for each interface (with virtual methods). Null if not using interfaces with method overrides. + /// + uint16* InterfacesOffsets; + /// /// The script methods VTable used by the wrapper functions attached to native object vtable. Cached to improve C#/VisualScript invocation performance. /// @@ -244,6 +256,13 @@ struct FLAXENGINE_API ScriptingType // Class destructor method pointer Dtor Dtor; } Class; + + struct + { + SetupScriptVTableHandler SetupScriptVTable; + SetupScriptObjectVTableHandler SetupScriptObjectVTable; + GetInterfaceWrapper GetInterfaceWrapper; + } Interface; }; ScriptingType(); @@ -251,7 +270,7 @@ struct FLAXENGINE_API ScriptingType ScriptingType(const StringAnsiView& fullname, BinaryModule* module, int32 size, InitRuntimeHandler initRuntime = DefaultInitRuntime, SpawnHandler spawn = DefaultSpawn, ScriptingTypeInitializer* baseType = nullptr, SetupScriptVTableHandler setupScriptVTable = nullptr, SetupScriptObjectVTableHandler setupScriptObjectVTable = nullptr, const InterfaceImplementation* interfaces = nullptr); ScriptingType(const StringAnsiView& fullname, BinaryModule* module, int32 size, InitRuntimeHandler initRuntime, Ctor ctor, Dtor dtor, ScriptingTypeInitializer* baseType, const InterfaceImplementation* interfaces = nullptr); ScriptingType(const StringAnsiView& fullname, BinaryModule* module, int32 size, InitRuntimeHandler initRuntime, Ctor ctor, Dtor dtor, Copy copy, Box box, Unbox unbox, GetField getField, SetField setField, ScriptingTypeInitializer* baseType, const InterfaceImplementation* interfaces = nullptr); - ScriptingType(const StringAnsiView& fullname, BinaryModule* module, InitRuntimeHandler initRuntime, ScriptingTypeInitializer* baseType, const InterfaceImplementation* interfaces = nullptr); + ScriptingType(const StringAnsiView& fullname, BinaryModule* module, InitRuntimeHandler initRuntime, SetupScriptVTableHandler setupScriptVTable, SetupScriptObjectVTableHandler setupScriptObjectVTable, GetInterfaceWrapper getInterfaceWrapper); ScriptingType(const ScriptingType& other); ScriptingType(ScriptingType&& other); ScriptingType& operator=(ScriptingType&& other) = delete; @@ -292,6 +311,9 @@ struct FLAXENGINE_API ScriptingType /// const InterfaceImplementation* GetInterface(const ScriptingTypeHandle& interfaceType) const; + void SetupScriptVTable(ScriptingTypeHandle baseTypeHandle); + void SetupScriptObjectVTable(void* object, ScriptingTypeHandle baseTypeHandle, int32 wrapperIndex); + void HackObjectVTable(void* object, ScriptingTypeHandle baseTypeHandle, int32 wrapperIndex); String ToString() const; }; @@ -303,7 +325,7 @@ struct FLAXENGINE_API ScriptingTypeInitializer : ScriptingTypeHandle ScriptingTypeInitializer(BinaryModule* module, const StringAnsiView& fullname, int32 size, ScriptingType::InitRuntimeHandler initRuntime = ScriptingType::DefaultInitRuntime, ScriptingType::SpawnHandler spawn = ScriptingType::DefaultSpawn, ScriptingTypeInitializer* baseType = nullptr, ScriptingType::SetupScriptVTableHandler setupScriptVTable = nullptr, ScriptingType::SetupScriptObjectVTableHandler setupScriptObjectVTable = nullptr, const ScriptingType::InterfaceImplementation* interfaces = nullptr); ScriptingTypeInitializer(BinaryModule* module, const StringAnsiView& fullname, int32 size, ScriptingType::InitRuntimeHandler initRuntime, ScriptingType::Ctor ctor, ScriptingType::Dtor dtor, ScriptingTypeInitializer* baseType = nullptr, const ScriptingType::InterfaceImplementation* interfaces = nullptr); ScriptingTypeInitializer(BinaryModule* module, const StringAnsiView& fullname, int32 size, ScriptingType::InitRuntimeHandler initRuntime, ScriptingType::Ctor ctor, ScriptingType::Dtor dtor, ScriptingType::Copy copy, ScriptingType::Box box, ScriptingType::Unbox unbox, ScriptingType::GetField getField, ScriptingType::SetField setField, ScriptingTypeInitializer* baseType = nullptr, const ScriptingType::InterfaceImplementation* interfaces = nullptr); - ScriptingTypeInitializer(BinaryModule* module, const StringAnsiView& fullname, ScriptingType::InitRuntimeHandler initRuntime, ScriptingTypeInitializer* baseType = nullptr, const ScriptingType::InterfaceImplementation* interfaces = nullptr); + ScriptingTypeInitializer(BinaryModule* module, const StringAnsiView& fullname, ScriptingType::InitRuntimeHandler initRuntime, ScriptingType::SetupScriptVTableHandler setupScriptVTable, ScriptingType::SetupScriptObjectVTableHandler setupScriptObjectVTable, ScriptingType::GetInterfaceWrapper getInterfaceWrapper); }; /// @@ -317,7 +339,7 @@ struct ScriptingObjectSpawnParams Guid ID; /// - /// The object type handle (script class might not be loaded yet). + /// The object type handle. /// const ScriptingTypeHandle Type; diff --git a/Source/Tools/Flax.Build/Bindings/BindingsGenerator.CSharp.cs b/Source/Tools/Flax.Build/Bindings/BindingsGenerator.CSharp.cs index 05fdff485..00cfc1c80 100644 --- a/Source/Tools/Flax.Build/Bindings/BindingsGenerator.CSharp.cs +++ b/Source/Tools/Flax.Build/Bindings/BindingsGenerator.CSharp.cs @@ -233,7 +233,7 @@ namespace Flax.Build.Bindings { CSharpUsedNamespaces.Add(apiType.Namespace); - if (apiType.IsScriptingObject) + if (apiType.IsScriptingObject || apiType.IsInterface) return typeInfo.Type.Replace("::", "."); if (typeInfo.IsPtr && apiType.IsPod) return typeInfo.Type.Replace("::", ".") + '*'; @@ -256,7 +256,7 @@ namespace Flax.Build.Bindings var apiType = FindApiTypeInfo(buildData, typeInfo, caller); if (apiType != null) { - if (apiType.IsScriptingObject) + if (apiType.IsScriptingObject || apiType.IsInterface) return "IntPtr"; } @@ -297,22 +297,26 @@ namespace Flax.Build.Bindings case "PersistentScriptingObject": case "ManagedScriptingObject": // object - return "FlaxEngine.Object.GetUnmanagedPtr"; + return "FlaxEngine.Object.GetUnmanagedPtr({0})"; case "Function": // delegate - return "Marshal.GetFunctionPointerForDelegate"; + return "Marshal.GetFunctionPointerForDelegate({0})"; default: var apiType = FindApiTypeInfo(buildData, typeInfo, caller); if (apiType != null) { // Scripting Object if (apiType.IsScriptingObject) - return "FlaxEngine.Object.GetUnmanagedPtr"; + return "FlaxEngine.Object.GetUnmanagedPtr({0})"; + + // interface + if (apiType.IsInterface) + return string.Format("FlaxEngine.Object.GetUnmanagedInterface({{0}}, typeof({0}))", apiType.FullNameManaged); } // ScriptingObjectReference or AssetReference or WeakAssetReference or SoftObjectReference if ((typeInfo.Type == "ScriptingObjectReference" || typeInfo.Type == "AssetReference" || typeInfo.Type == "WeakAssetReference" || typeInfo.Type == "SoftObjectReference") && typeInfo.GenericArgs != null) - return "FlaxEngine.Object.GetUnmanagedPtr"; + return "FlaxEngine.Object.GetUnmanagedPtr({0})"; // Default return string.Empty; @@ -358,8 +362,10 @@ namespace Flax.Build.Bindings contents.Append(parameterInfo.Name); } - foreach (var parameterInfo in functionInfo.Glue.CustomParameters) + var customParametersCount = functionInfo.Glue.CustomParameters?.Count ?? 0; + for (var i = 0; i < customParametersCount; i++) { + var parameterInfo = functionInfo.Glue.CustomParameters[i]; if (separator) contents.Append(", "); separator = true; @@ -424,15 +430,14 @@ namespace Flax.Build.Bindings throw new Exception($"Cannot use Ref meta on parameter {parameterInfo} in function {functionInfo.Name} in {caller}."); // Convert value - contents.Append(convertFunc); - contents.Append('('); - contents.Append(isSetter ? "value" : parameterInfo.Name); - contents.Append(')'); + contents.Append(string.Format(convertFunc, isSetter ? "value" : parameterInfo.Name)); } } - foreach (var parameterInfo in functionInfo.Glue.CustomParameters) + var customParametersCount = functionInfo.Glue.CustomParameters?.Count ?? 0; + for (var i = 0; i < customParametersCount; i++) { + var parameterInfo = functionInfo.Glue.CustomParameters[i]; if (separator) contents.Append(", "); separator = true; @@ -451,10 +456,7 @@ namespace Flax.Build.Bindings else { // Convert value - contents.Append(convertFunc); - contents.Append('('); - contents.Append(parameterInfo.DefaultValue); - contents.Append(')'); + contents.Append(string.Format(convertFunc, parameterInfo.DefaultValue)); } } @@ -571,8 +573,24 @@ namespace Flax.Build.Bindings else if (classInfo.IsAbstract) contents.Append("abstract "); contents.Append("unsafe partial class ").Append(classInfo.Name); - if (classInfo.BaseType != null && !classInfo.IsBaseTypeHidden) + var hasBase = classInfo.BaseType != null && !classInfo.IsBaseTypeHidden; + if (hasBase) contents.Append(" : ").Append(GenerateCSharpNativeToManaged(buildData, new TypeInfo { Type = classInfo.BaseType.Name }, classInfo)); + var hasInterface = false; + if (classInfo.Interfaces != null) + { + foreach (var interfaceInfo in classInfo.Interfaces) + { + if (interfaceInfo.Access != AccessLevel.Public) + continue; + if (hasInterface || hasBase) + contents.Append(", "); + else + contents.Append(" : "); + hasInterface = true; + contents.Append(interfaceInfo.FullNameManaged); + } + } contents.AppendLine(); contents.Append(indent + "{"); indent += " "; @@ -902,6 +920,71 @@ namespace Flax.Build.Bindings GenerateCSharpWrapperFunction(buildData, contents, indent, classInfo, functionInfo); } + // Interface implementation + if (hasInterface) + { + foreach (var interfaceInfo in classInfo.Interfaces) + { + if (interfaceInfo.Access != AccessLevel.Public) + continue; + foreach (var functionInfo in interfaceInfo.Functions) + { + if (!classInfo.IsScriptingObject) + throw new Exception($"Class {classInfo.Name} cannot implement interface {interfaceInfo.Name} because it requires ScriptingObject as a base class."); + + contents.AppendLine(); + if (functionInfo.Comment.Length != 0) + contents.Append(indent).AppendLine("/// "); + GenerateCSharpAttributes(buildData, contents, indent, classInfo, functionInfo.Attributes, null, false, useUnmanaged); + contents.Append(indent); + if (functionInfo.Access == AccessLevel.Public) + contents.Append("public "); + else if (functionInfo.Access == AccessLevel.Protected) + contents.Append("protected "); + else if (functionInfo.Access == AccessLevel.Private) + contents.Append("private "); + if (functionInfo.IsVirtual && !classInfo.IsSealed) + contents.Append("virtual "); + var returnValueType = GenerateCSharpNativeToManaged(buildData, functionInfo.ReturnType, classInfo); + contents.Append(returnValueType).Append(' ').Append(functionInfo.Name).Append('('); + + for (var i = 0; i < functionInfo.Parameters.Count; i++) + { + var parameterInfo = functionInfo.Parameters[i]; + if (i != 0) + contents.Append(", "); + if (!string.IsNullOrEmpty(parameterInfo.Attributes)) + contents.Append('[').Append(parameterInfo.Attributes).Append(']').Append(' '); + + var managedType = GenerateCSharpNativeToManaged(buildData, parameterInfo.Type, classInfo); + if (parameterInfo.IsOut) + contents.Append("out "); + else if (parameterInfo.IsRef) + contents.Append("ref "); + contents.Append(managedType); + contents.Append(' '); + contents.Append(parameterInfo.Name); + + var defaultValue = GenerateCSharpDefaultValueNativeToManaged(buildData, parameterInfo.DefaultValue, classInfo); + if (!string.IsNullOrEmpty(defaultValue)) + contents.Append(" = ").Append(defaultValue); + } + + contents.Append(')').AppendLine().AppendLine(indent + "{"); + indent += " "; + + contents.Append(indent); + GenerateCSharpWrapperFunctionCall(buildData, contents, classInfo, functionInfo); + + indent = indent.Substring(0, indent.Length - 4); + contents.AppendLine(); + contents.AppendLine(indent + "}"); + + GenerateCSharpWrapperFunction(buildData, contents, indent, classInfo, functionInfo); + } + } + } + // Nested types foreach (var apiTypeInfo in classInfo.Children) { @@ -1109,6 +1192,81 @@ namespace Flax.Build.Bindings } } + private static void GenerateCSharpInterface(BuildData buildData, StringBuilder contents, string indent, InterfaceInfo interfaceInfo) + { + // Begin + contents.AppendLine(); + if (!string.IsNullOrEmpty(interfaceInfo.Namespace)) + { + contents.AppendFormat("namespace "); + contents.AppendLine(interfaceInfo.Namespace); + contents.AppendLine("{"); + indent += " "; + } + GenerateCSharpComment(contents, indent, interfaceInfo.Comment); + GenerateCSharpAttributes(buildData, contents, indent, interfaceInfo, true); + contents.Append(indent); + if (interfaceInfo.Access == AccessLevel.Public) + contents.Append("public "); + else if (interfaceInfo.Access == AccessLevel.Protected) + contents.Append("protected "); + else if (interfaceInfo.Access == AccessLevel.Private) + contents.Append("private "); + contents.Append("unsafe partial interface ").Append(interfaceInfo.Name); + contents.AppendLine(); + contents.Append(indent + "{"); + indent += " "; + + // Functions + foreach (var functionInfo in interfaceInfo.Functions) + { + if (functionInfo.IsStatic) + throw new Exception($"Not supported {"static"} function {functionInfo.Name} inside interface {interfaceInfo.Name}."); + if (functionInfo.NoProxy) + throw new Exception($"Not supported {"NoProxy"} function {functionInfo.Name} inside interface {interfaceInfo.Name}."); + if (!functionInfo.IsVirtual) + throw new Exception($"Not supported {"non-virtual"} function {functionInfo.Name} inside interface {interfaceInfo.Name}."); + if (functionInfo.Access != AccessLevel.Public) + throw new Exception($"Not supported {"non-public"} function {functionInfo.Name} inside interface {interfaceInfo.Name}."); + + contents.AppendLine(); + GenerateCSharpComment(contents, indent, functionInfo.Comment); + GenerateCSharpAttributes(buildData, contents, indent, interfaceInfo, functionInfo, true); + var returnValueType = GenerateCSharpNativeToManaged(buildData, functionInfo.ReturnType, interfaceInfo); + contents.Append(indent).Append(returnValueType).Append(' ').Append(functionInfo.Name).Append('('); + + for (var i = 0; i < functionInfo.Parameters.Count; i++) + { + var parameterInfo = functionInfo.Parameters[i]; + if (i != 0) + contents.Append(", "); + if (!string.IsNullOrEmpty(parameterInfo.Attributes)) + contents.Append('[').Append(parameterInfo.Attributes).Append(']').Append(' '); + + var managedType = GenerateCSharpNativeToManaged(buildData, parameterInfo.Type, interfaceInfo); + if (parameterInfo.IsOut) + contents.Append("out "); + else if (parameterInfo.IsRef) + contents.Append("ref "); + contents.Append(managedType); + contents.Append(' '); + contents.Append(parameterInfo.Name); + + var defaultValue = GenerateCSharpDefaultValueNativeToManaged(buildData, parameterInfo.DefaultValue, interfaceInfo); + if (!string.IsNullOrEmpty(defaultValue)) + contents.Append(" = ").Append(defaultValue); + } + + contents.Append(");").AppendLine(); + } + + // End + indent = indent.Substring(0, indent.Length - 4); + contents.AppendLine(indent + "}"); + if (!string.IsNullOrEmpty(interfaceInfo.Namespace)) + contents.AppendLine("}"); + } + private static bool GenerateCSharpType(BuildData buildData, StringBuilder contents, string indent, object type) { if (type is ApiTypeInfo apiTypeInfo && apiTypeInfo.IsInBuild) @@ -1122,6 +1280,8 @@ namespace Flax.Build.Bindings GenerateCSharpStructure(buildData, contents, indent, structureInfo); else if (type is EnumInfo enumInfo) GenerateCSharpEnum(buildData, contents, indent, enumInfo); + else if (type is InterfaceInfo interfaceInfo) + GenerateCSharpInterface(buildData, contents, indent, interfaceInfo); else return false; } diff --git a/Source/Tools/Flax.Build/Bindings/BindingsGenerator.Cache.cs b/Source/Tools/Flax.Build/Bindings/BindingsGenerator.Cache.cs index 0bb1f4cc7..079d9ea0a 100644 --- a/Source/Tools/Flax.Build/Bindings/BindingsGenerator.Cache.cs +++ b/Source/Tools/Flax.Build/Bindings/BindingsGenerator.Cache.cs @@ -19,7 +19,7 @@ namespace Flax.Build.Bindings partial class BindingsGenerator { private static readonly Dictionary TypeCache = new Dictionary(); - private const int CacheVersion = 8; + private const int CacheVersion = 9; internal static void Write(BinaryWriter writer, string e) { diff --git a/Source/Tools/Flax.Build/Bindings/BindingsGenerator.Cpp.cs b/Source/Tools/Flax.Build/Bindings/BindingsGenerator.Cpp.cs index 69ceaf6ce..c4daa9180 100644 --- a/Source/Tools/Flax.Build/Bindings/BindingsGenerator.Cpp.cs +++ b/Source/Tools/Flax.Build/Bindings/BindingsGenerator.Cpp.cs @@ -35,9 +35,9 @@ namespace Flax.Build.Bindings public static event Action, StringBuilder> GenerateCppBinaryModuleHeader; public static event Action, StringBuilder> GenerateCppBinaryModuleSource; public static event Action GenerateCppModuleSource; - public static event Action GenerateCppClassInternals; - public static event Action GenerateCppClassInitRuntime; - public static event Action GenerateCppScriptWrapperFunction; + public static event Action GenerateCppClassInternals; + public static event Action GenerateCppClassInitRuntime; + public static event Action GenerateCppScriptWrapperFunction; private static readonly List CppInBuildVariantStructures = new List { @@ -458,6 +458,13 @@ namespace Flax.Build.Bindings return "ScriptingObject::ToManaged((ScriptingObject*){0})"; } + // interface + if (apiType.IsInterface) + { + type = "MonoObject*"; + return "ScriptingObject::ToManaged(ScriptingObject::FromInterface({0}, " + apiType.NativeName + "::TypeInitializer))"; + } + // Non-POD structure passed as value (eg. it contains string or array inside) if (apiType.IsStruct && !apiType.IsPod) { @@ -977,7 +984,24 @@ namespace Flax.Build.Bindings contents.AppendLine(); } - private static void GenerateCppManagedWrapperFunction(BuildData buildData, StringBuilder contents, ClassInfo classInfo, FunctionInfo functionInfo, int scriptVTableSize, int scriptVTableIndex) + public static void GenerateCppReturn(BuildData buildData, StringBuilder contents, string indent, TypeInfo type) + { + contents.Append(indent); + if (type.IsVoid) + { + contents.AppendLine("return;"); + return; + } + if (type.IsPtr) + { + contents.AppendLine("return nullptr;"); + return; + } + contents.AppendLine($"{type} __return {{}};"); + contents.Append(indent).AppendLine("return __return;"); + } + + private static void GenerateCppManagedWrapperFunction(BuildData buildData, StringBuilder contents, VirtualClassInfo classInfo, FunctionInfo functionInfo, int scriptVTableSize, int scriptVTableIndex) { contents.AppendFormat(" {0} {1}_ManagedWrapper(", functionInfo.ReturnType, functionInfo.UniqueName); var separator = false; @@ -998,7 +1022,23 @@ namespace Flax.Build.Bindings contents.Append(')'); contents.AppendLine(); contents.AppendLine(" {"); - contents.AppendLine($" auto object = ({classInfo.NativeName}*)this;"); + string scriptVTableOffset; + if (classInfo.IsInterface) + { + contents.AppendLine($" auto object = ScriptingObject::FromInterface(this, {classInfo.NativeName}::TypeInitializer);"); + contents.AppendLine(" if (object == nullptr)"); + contents.AppendLine(" {"); + contents.AppendLine($" LOG(Error, \"Failed to cast interface {{0}} to scripting object\", TEXT(\"{classInfo.Name}\"));"); + GenerateCppReturn(buildData, contents, " ", functionInfo.ReturnType); + contents.AppendLine(" }"); + contents.AppendLine($" const int32 scriptVTableOffset = {scriptVTableIndex} + object->GetType().GetInterface({classInfo.NativeName}::TypeInitializer)->ScriptVTableOffset;"); + scriptVTableOffset = "scriptVTableOffset"; + } + else + { + contents.AppendLine($" auto object = ({classInfo.NativeName}*)this;"); + scriptVTableOffset = scriptVTableIndex.ToString(); + } contents.AppendLine(" static THREADLOCAL void* WrapperCallInstance = nullptr;"); contents.AppendLine(" ScriptingTypeHandle managedTypeHandle = object->GetTypeHandle();"); @@ -1013,7 +1053,7 @@ namespace Flax.Build.Bindings contents.AppendLine(" {"); contents.AppendLine(" // Prevent stack overflow by calling native base method"); contents.AppendLine(" const auto scriptVTableBase = managedTypePtr->Script.ScriptVTableBase;"); - contents.Append($" return (this->**({functionInfo.UniqueName}_Internal_Signature*)&scriptVTableBase[{scriptVTableIndex} + 2])("); + contents.Append($" return (this->**({functionInfo.UniqueName}_Internal_Signature*)&scriptVTableBase[{scriptVTableOffset} + 2])("); separator = false; for (var i = 0; i < functionInfo.Parameters.Count; i++) { @@ -1026,8 +1066,8 @@ namespace Flax.Build.Bindings contents.AppendLine(");"); contents.AppendLine(" }"); contents.AppendLine(" auto scriptVTable = (MMethod**)managedTypePtr->Script.ScriptVTable;"); - contents.AppendLine($" ASSERT(scriptVTable && scriptVTable[{scriptVTableIndex}]);"); - contents.AppendLine($" auto method = scriptVTable[{scriptVTableIndex}];"); + contents.AppendLine($" ASSERT(scriptVTable && scriptVTable[{scriptVTableOffset}]);"); + contents.AppendLine($" auto method = scriptVTable[{scriptVTableOffset}];"); contents.AppendLine($" PROFILE_CPU_NAMED(\"{classInfo.FullNameManaged}::{functionInfo.Name}\");"); contents.AppendLine(" MonoObject* exception = nullptr;"); @@ -1122,6 +1162,107 @@ namespace Flax.Build.Bindings contents.AppendLine(); } + private static string GenerateCppScriptVTable(BuildData buildData, StringBuilder contents, VirtualClassInfo classInfo) + { + var scriptVTableSize = classInfo.GetScriptVTableSize(out var scriptVTableOffset); + if (scriptVTableSize == 0) + return "nullptr, nullptr"; + + var scriptVTableIndex = scriptVTableOffset; + foreach (var functionInfo in classInfo.Functions) + { + if (!functionInfo.IsVirtual) + continue; + GenerateCppManagedWrapperFunction(buildData, contents, classInfo, functionInfo, scriptVTableSize, scriptVTableIndex); + GenerateCppScriptWrapperFunction?.Invoke(buildData, classInfo, functionInfo, scriptVTableSize, scriptVTableIndex, contents); + scriptVTableIndex++; + } + + CppIncludeFiles.Add("Engine/Scripting/ManagedCLR/MMethod.h"); + CppIncludeFiles.Add("Engine/Scripting/ManagedCLR/MClass.h"); + + foreach (var functionInfo in classInfo.Functions) + { + if (!functionInfo.IsVirtual) + continue; + var thunkParams = string.Empty; + var separator = false; + for (var i = 0; i < functionInfo.Parameters.Count; i++) + { + var parameterInfo = functionInfo.Parameters[i]; + if (separator) + thunkParams += ", "; + separator = true; + thunkParams += parameterInfo.Type; + } + var t = functionInfo.IsConst ? " const" : string.Empty; + contents.AppendLine($" typedef {functionInfo.ReturnType} ({classInfo.NativeName}::*{functionInfo.UniqueName}_Signature)({thunkParams}){t};"); + contents.AppendLine($" typedef {functionInfo.ReturnType} ({classInfo.NativeName}Internal::*{functionInfo.UniqueName}_Internal_Signature)({thunkParams}){t};"); + } + contents.AppendLine(""); + + contents.AppendLine(" static void SetupScriptVTable(MClass* mclass, void**& scriptVTable, void**& scriptVTableBase)"); + contents.AppendLine(" {"); + if (classInfo.IsInterface) + { + contents.AppendLine(" ASSERT(scriptVTable);"); + } + else + { + contents.AppendLine(" if (!scriptVTable)"); + contents.AppendLine(" {"); + contents.AppendLine($" scriptVTable = (void**)Platform::Allocate(sizeof(void*) * {scriptVTableSize + 1}, 16);"); + contents.AppendLine($" Platform::MemoryClear(scriptVTable, sizeof(void*) * {scriptVTableSize + 1});"); + contents.AppendLine($" scriptVTableBase = (void**)Platform::Allocate(sizeof(void*) * {scriptVTableSize + 2}, 16);"); + contents.AppendLine(" }"); + } + scriptVTableIndex = scriptVTableOffset; + foreach (var functionInfo in classInfo.Functions) + { + if (!functionInfo.IsVirtual) + continue; + contents.AppendLine($" scriptVTable[{scriptVTableIndex++}] = mclass->GetMethod(\"{functionInfo.Name}\", {functionInfo.Parameters.Count});"); + } + contents.AppendLine(" }"); + contents.AppendLine(""); + + contents.AppendLine(" static void SetupScriptObjectVTable(void** scriptVTable, void** scriptVTableBase, void** vtable, int32 entriesCount, int32 wrapperIndex)"); + contents.AppendLine(" {"); + scriptVTableIndex = scriptVTableOffset; + foreach (var functionInfo in classInfo.Functions) + { + if (!functionInfo.IsVirtual) + continue; + contents.AppendLine($" if (scriptVTable[{scriptVTableIndex}])"); + contents.AppendLine(" {"); + contents.AppendLine($" {functionInfo.UniqueName}_Signature funcPtr = &{classInfo.NativeName}::{functionInfo.Name};"); + contents.AppendLine(" const int32 vtableIndex = GetVTableIndex(vtable, entriesCount, *(void**)&funcPtr);"); + contents.AppendLine(" if (vtableIndex > 0 && vtableIndex < entriesCount)"); + contents.AppendLine(" {"); + contents.AppendLine($" scriptVTableBase[{scriptVTableIndex} + 2] = vtable[vtableIndex];"); + for (var i = 0; i < CppScriptObjectVirtualWrapperMethodsPostfixes.Count; i++) + { + contents.AppendLine(i == 0 ? " if (wrapperIndex == 0)" : $" else if (wrapperIndex == {i})"); + contents.AppendLine(" {"); + contents.AppendLine($" auto thunkPtr = &{classInfo.NativeName}Internal::{functionInfo.UniqueName}{CppScriptObjectVirtualWrapperMethodsPostfixes[i]};"); + contents.AppendLine(" vtable[vtableIndex] = *(void**)&thunkPtr;"); + contents.AppendLine(" }"); + } + contents.AppendLine(" }"); + contents.AppendLine(" else"); + contents.AppendLine(" {"); + contents.AppendLine($" LOG(Error, \"Failed to find the vtable entry for method {{0}} in class {{1}}\", TEXT(\"{functionInfo.Name}\"), TEXT(\"{classInfo.Name}\"));"); + contents.AppendLine(" }"); + contents.AppendLine(" }"); + + scriptVTableIndex++; + } + contents.AppendLine(" }"); + contents.AppendLine(""); + + return $"&{classInfo.NativeName}Internal::SetupScriptVTable, &{classInfo.NativeName}Internal::SetupScriptObjectVTable"; + } + private static string GenerateCppAutoSerializationDefineType(BuildData buildData, StringBuilder contents, ModuleInfo moduleInfo, ApiTypeInfo caller, TypeInfo memberType, MemberInfo member) { if (memberType.IsBitField) @@ -1225,7 +1366,7 @@ namespace Flax.Build.Bindings contents.Append('}').AppendLine(); } - private static string GenerateCppInterfaceInheritanceTable(BuildData buildData, StringBuilder contents, ModuleInfo moduleInfo, ClassStructInfo typeInfo, string typeNameNative) + private static string GenerateCppInterfaceInheritanceTable(BuildData buildData, StringBuilder contents, ModuleInfo moduleInfo, VirtualClassInfo typeInfo, string typeNameNative) { var interfacesPtr = "nullptr"; var interfaces = typeInfo.Interfaces; @@ -1236,7 +1377,8 @@ namespace Flax.Build.Bindings for (int i = 0; i < interfaces.Count; i++) { var interfaceInfo = interfaces[i]; - contents.Append(" { &").Append(interfaceInfo.NativeName).Append("::TypeInitializer, (int16)VTABLE_OFFSET(").Append(typeInfo.NativeName).Append(", ").Append(interfaceInfo.NativeName).AppendLine(") },"); + var scriptVTableOffset = typeInfo.GetScriptVTableOffset(interfaceInfo); + contents.AppendLine($" {{ &{interfaceInfo.NativeName}::TypeInitializer, (int16)VTABLE_OFFSET({typeInfo.NativeName}, {interfaceInfo.NativeName}), {scriptVTableOffset}, true }},"); } contents.AppendLine(" { nullptr, 0 },"); contents.AppendLine("};"); @@ -1253,6 +1395,7 @@ namespace Flax.Build.Bindings if (classInfo.Parent != null && !(classInfo.Parent is FileInfo)) classTypeNameInternal = classInfo.Parent.FullNameNative + '_' + classTypeNameInternal; var useScripting = classInfo.IsStatic || classInfo.IsScriptingObject; + var hasInterface = classInfo.Interfaces != null && classInfo.Interfaces.Any(x => x.Access == AccessLevel.Public); if (classInfo.IsAutoSerialization) GenerateCppAutoSerialization(buildData, contents, moduleInfo, classInfo, classTypeNameNative); @@ -1424,110 +1567,27 @@ namespace Flax.Build.Bindings GenerateCppWrapperFunction(buildData, contents, classInfo, functionInfo); } - GenerateCppClassInternals?.Invoke(buildData, classInfo, contents); - - // Virtual methods overrides - var setupScriptVTable = "nullptr, nullptr"; - if (!classInfo.IsSealed) + // Interface implementation + if (hasInterface) { - var scriptVTableSize = classInfo.GetScriptVTableSize(buildData, out var scriptVTableOffset); - if (scriptVTableSize > 0) + foreach (var interfaceInfo in classInfo.Interfaces) { - var scriptVTableIndex = scriptVTableOffset; - foreach (var functionInfo in classInfo.Functions) + if (interfaceInfo.Access != AccessLevel.Public) + continue; + foreach (var functionInfo in interfaceInfo.Functions) { - if (functionInfo.IsVirtual) - { - GenerateCppManagedWrapperFunction(buildData, contents, classInfo, functionInfo, scriptVTableSize, scriptVTableIndex); - GenerateCppScriptWrapperFunction?.Invoke(buildData, classInfo, functionInfo, scriptVTableSize, scriptVTableIndex, contents); - scriptVTableIndex++; - } - } - - if (scriptVTableOffset != scriptVTableSize) - { - CppIncludeFiles.Add("Engine/Scripting/ManagedCLR/MMethod.h"); - CppIncludeFiles.Add("Engine/Scripting/ManagedCLR/MClass.h"); - - foreach (var functionInfo in classInfo.Functions) - { - if (!functionInfo.IsVirtual) - continue; - - var thunkParams = string.Empty; - var separator = false; - for (var i = 0; i < functionInfo.Parameters.Count; i++) - { - var parameterInfo = functionInfo.Parameters[i]; - if (separator) - thunkParams += ", "; - separator = true; - thunkParams += parameterInfo.Type; - } - var t = functionInfo.IsConst ? " const" : string.Empty; - contents.AppendLine($" typedef {functionInfo.ReturnType} ({classInfo.NativeName}::*{functionInfo.UniqueName}_Signature)({thunkParams}){t};"); - contents.AppendLine($" typedef {functionInfo.ReturnType} ({classInfo.NativeName}Internal::*{functionInfo.UniqueName}_Internal_Signature)({thunkParams}){t};"); - } - contents.AppendLine(""); - - contents.AppendLine(" static void SetupScriptVTable(MClass* mclass, void**& scriptVTable, void**& scriptVTableBase)"); - contents.AppendLine(" {"); - contents.AppendLine(" if (!scriptVTable)"); - contents.AppendLine(" {"); - contents.AppendLine($" scriptVTable = (void**)Platform::Allocate(sizeof(void*) * {scriptVTableSize + 1}, 16);"); - contents.AppendLine($" Platform::MemoryClear(scriptVTable, sizeof(void*) * {scriptVTableSize + 1});"); - contents.AppendLine($" scriptVTableBase = (void**)Platform::Allocate(sizeof(void*) * {scriptVTableSize + 2}, 16);"); - contents.AppendLine(" }"); - scriptVTableIndex = scriptVTableOffset; - foreach (var functionInfo in classInfo.Functions) - { - if (!functionInfo.IsVirtual) - continue; - contents.AppendLine($" scriptVTable[{scriptVTableIndex++}] = mclass->GetMethod(\"{functionInfo.Name}\", {functionInfo.Parameters.Count});"); - } - contents.AppendLine(" }"); - contents.AppendLine(""); - - contents.AppendLine(" static void SetupScriptObjectVTable(void** scriptVTable, void** scriptVTableBase, void** vtable, int32 entriesCount, int32 wrapperIndex)"); - contents.AppendLine(" {"); - scriptVTableIndex = scriptVTableOffset; - foreach (var functionInfo in classInfo.Functions) - { - if (!functionInfo.IsVirtual) - continue; - - contents.AppendLine($" if (scriptVTable[{scriptVTableIndex}])"); - contents.AppendLine(" {"); - contents.AppendLine($" {functionInfo.UniqueName}_Signature funcPtr = &{classInfo.NativeName}::{functionInfo.Name};"); - contents.AppendLine(" const int32 vtableIndex = GetVTableIndex(vtable, entriesCount, *(void**)&funcPtr);"); - contents.AppendLine(" if (vtableIndex > 0 && vtableIndex < entriesCount)"); - contents.AppendLine(" {"); - contents.AppendLine($" scriptVTableBase[{scriptVTableIndex} + 2] = vtable[vtableIndex];"); - for (var i = 0; i < CppScriptObjectVirtualWrapperMethodsPostfixes.Count; i++) - { - contents.AppendLine(i == 0 ? " if (wrapperIndex == 0)" : $" else if (wrapperIndex == {i})"); - contents.AppendLine(" {"); - contents.AppendLine($" auto thunkPtr = &{classInfo.NativeName}Internal::{functionInfo.UniqueName}{CppScriptObjectVirtualWrapperMethodsPostfixes[i]};"); - contents.AppendLine(" vtable[vtableIndex] = *(void**)&thunkPtr;"); - contents.AppendLine(" }"); - } - contents.AppendLine(" }"); - contents.AppendLine(" else"); - contents.AppendLine(" {"); - contents.AppendLine($" LOG(Error, \"Failed to find the vtable entry for method {{0}} in class {{1}}\", TEXT(\"{functionInfo.Name}\"), TEXT(\"{classInfo.Name}\"));"); - contents.AppendLine(" }"); - contents.AppendLine(" }"); - - scriptVTableIndex++; - } - contents.AppendLine(" }"); - contents.AppendLine(""); - - setupScriptVTable = $"&{classInfo.NativeName}Internal::SetupScriptVTable, &{classInfo.NativeName}Internal::SetupScriptObjectVTable"; + if (!classInfo.IsScriptingObject) + throw new Exception($"Class {classInfo.Name} cannot implement interface {interfaceInfo.Name} because it requires ScriptingObject as a base class."); + GenerateCppWrapperFunction(buildData, contents, classInfo, functionInfo); } } } + GenerateCppClassInternals?.Invoke(buildData, classInfo, contents); + + // Virtual methods overrides + var setupScriptVTable = GenerateCppScriptVTable(buildData, contents, classInfo); + // Runtime initialization (internal methods binding) contents.AppendLine(" static void InitRuntime()"); contents.AppendLine(" {"); @@ -1558,6 +1618,18 @@ namespace Flax.Build.Bindings { contents.AppendLine($" ADD_INTERNAL_CALL(\"{classTypeNameManagedInternalCall}::Internal_{functionInfo.UniqueName}\", &{functionInfo.UniqueName});"); } + if (hasInterface) + { + foreach (var interfaceInfo in classInfo.Interfaces) + { + if (interfaceInfo.Access != AccessLevel.Public) + continue; + foreach (var functionInfo in interfaceInfo.Functions) + { + contents.AppendLine($" ADD_INTERNAL_CALL(\"{classTypeNameManagedInternalCall}::Internal_{functionInfo.UniqueName}\", &{functionInfo.UniqueName});"); + } + } + } } GenerateCppClassInitRuntime?.Invoke(buildData, classInfo, contents); @@ -1811,29 +1883,102 @@ namespace Flax.Build.Bindings { var interfaceTypeNameNative = interfaceInfo.FullNameNative; var interfaceTypeNameManaged = interfaceInfo.FullNameManaged; - var interfaceTypeNameManagedInternalCall = interfaceTypeNameManaged.Replace('+', '/'); var interfaceTypeNameInternal = interfaceInfo.NativeName; if (interfaceInfo.Parent != null && !(interfaceInfo.Parent is FileInfo)) interfaceTypeNameInternal = interfaceInfo.Parent.FullNameNative + '_' + interfaceTypeNameInternal; + // Wrapper interface implement to invoke scripting if inherited in C# or VS + contents.AppendLine(); + contents.AppendFormat("class {0}Wrapper : public ", interfaceTypeNameInternal).Append(interfaceTypeNameNative).AppendLine(); + contents.Append('{').AppendLine(); + contents.AppendLine("public:"); + contents.AppendLine(" ScriptingObject* Object;"); + foreach (var functionInfo in interfaceInfo.Functions) + { + if (!functionInfo.IsVirtual) + continue; + contents.AppendFormat(" {0} {1}(", functionInfo.ReturnType, functionInfo.Name); + var separator = false; + for (var i = 0; i < functionInfo.Parameters.Count; i++) + { + var parameterInfo = functionInfo.Parameters[i]; + if (separator) + contents.Append(", "); + separator = true; + contents.Append(parameterInfo.Type).Append(' ').Append(parameterInfo.Name); + } + contents.Append(") override").AppendLine(); + contents.AppendLine(" {"); + // TODO: try to use ScriptVTable for interfaces implementation in scripting to call proper function instead of manually check at runtime + if (functionInfo.Parameters.Count != 0) + { + contents.AppendLine($" Variant parameters[{functionInfo.Parameters.Count}];"); + for (var i = 0; i < functionInfo.Parameters.Count; i++) + { + var parameterInfo = functionInfo.Parameters[i]; + contents.AppendLine($" parameters[{i}] = {GenerateCppWrapperNativeToVariant(buildData, parameterInfo.Type, interfaceInfo, parameterInfo.Name)};"); + } + } + else + { + contents.AppendLine(" Variant* parameters = nullptr;"); + } + contents.AppendLine(" auto typeHandle = Object->GetTypeHandle();"); + contents.AppendLine(" while (typeHandle)"); + contents.AppendLine(" {"); + contents.AppendLine($" auto method = typeHandle.Module->FindMethod(typeHandle, \"{functionInfo.Name}\", {functionInfo.Parameters.Count});"); + contents.AppendLine(" if (method)"); + contents.AppendLine(" {"); + contents.AppendLine(" Variant __result;"); + contents.AppendLine($" typeHandle.Module->InvokeMethod(method, Object, Span(parameters, {functionInfo.Parameters.Count}), __result);"); + if (functionInfo.ReturnType.IsVoid) + contents.AppendLine(" return;"); + else + contents.AppendLine($" return {GenerateCppWrapperVariantToNative(buildData, functionInfo.ReturnType, interfaceInfo, "__result")};"); + contents.AppendLine(" }"); + contents.AppendLine(" typeHandle = typeHandle.GetType().GetBaseType();"); + contents.AppendLine(" }"); + GenerateCppReturn(buildData, contents, " ", functionInfo.ReturnType); + contents.AppendLine(" }"); + } + if (interfaceInfo.Name == "ISerializable") + { + // TODO: how to handle other interfaces that have some abstract native methods? maybe NativeOnly tag on interface? do it right and remove this hack + contents.AppendLine(" void Serialize(SerializeStream& stream, const void* otherObj) override {} void Deserialize(DeserializeStream& stream, ISerializeModifier* modifier) override {}"); + } + contents.Append('}').Append(';').AppendLine(); + contents.AppendLine(); contents.AppendFormat("class {0}Internal", interfaceTypeNameInternal).AppendLine(); contents.Append('{').AppendLine(); contents.AppendLine("public:"); + GenerateCppClassInternals?.Invoke(buildData, interfaceInfo, contents); + + // Virtual methods overrides + var setupScriptVTable = GenerateCppScriptVTable(buildData, contents, interfaceInfo); + // Runtime initialization (internal methods binding) contents.AppendLine(" static void InitRuntime()"); contents.AppendLine(" {"); + GenerateCppClassInitRuntime?.Invoke(buildData, interfaceInfo, contents); contents.AppendLine(" }").AppendLine(); + // Interface implementation wrapper accessor for scripting types + contents.AppendLine(" static void* GetInterfaceWrapper(ScriptingObject* obj)"); + contents.AppendLine(" {"); + contents.AppendLine($" auto wrapper = New<{interfaceTypeNameInternal}Wrapper>();"); + contents.AppendLine(" wrapper->Object = obj;"); + contents.AppendLine(" return wrapper;"); + contents.AppendLine(" }"); + contents.Append('}').Append(';').AppendLine(); contents.AppendLine(); // Type initializer contents.Append($"ScriptingTypeInitializer {interfaceTypeNameNative}::TypeInitializer((BinaryModule*)GetBinaryModule{moduleInfo.Name}(), "); - contents.Append($"StringAnsiView(\"{interfaceTypeNameManaged}\", {interfaceTypeNameManaged.Length}), "); - contents.Append($"&{interfaceTypeNameInternal}Internal::InitRuntime"); - contents.Append(");"); + contents.Append($"StringAnsiView(\"{interfaceTypeNameManaged}\", {interfaceTypeNameManaged.Length}), &{interfaceTypeNameInternal}Internal::InitRuntime,"); + contents.Append(setupScriptVTable).Append($", &{interfaceTypeNameInternal}Internal::GetInterfaceWrapper").Append(");"); contents.AppendLine(); // Nested types diff --git a/Source/Tools/Flax.Build/Bindings/BindingsGenerator.cs b/Source/Tools/Flax.Build/Bindings/BindingsGenerator.cs index fa219477a..6645bc484 100644 --- a/Source/Tools/Flax.Build/Bindings/BindingsGenerator.cs +++ b/Source/Tools/Flax.Build/Bindings/BindingsGenerator.cs @@ -243,6 +243,8 @@ namespace Flax.Build.Bindings classInfo.Functions.Add(functionInfo); else if (context.ScopeInfo is StructureInfo structureInfo) structureInfo.Functions.Add(functionInfo); + else if (context.ScopeInfo is InterfaceInfo interfaceInfo) + interfaceInfo.Functions.Add(functionInfo); else throw new Exception($"Not supported free-function {functionInfo.Name} at line {tokenizer.CurrentLine}. Place it in the class to use API bindings for it."); } diff --git a/Source/Tools/Flax.Build/Bindings/ClassInfo.cs b/Source/Tools/Flax.Build/Bindings/ClassInfo.cs index f591dc004..c762b927c 100644 --- a/Source/Tools/Flax.Build/Bindings/ClassInfo.cs +++ b/Source/Tools/Flax.Build/Bindings/ClassInfo.cs @@ -10,7 +10,7 @@ namespace Flax.Build.Bindings /// /// The native class information for bindings generator. /// - public class ClassInfo : ClassStructInfo + public class ClassInfo : VirtualClassInfo { private static readonly HashSet InBuildScriptingObjectTypes = new HashSet { @@ -31,13 +31,10 @@ namespace Flax.Build.Bindings public bool IsAutoSerialization; public bool NoSpawn; public bool NoConstructor; - public List Functions = new List(); public List Properties = new List(); public List Fields = new List(); public List Events = new List(); - internal HashSet UniqueFunctionNames; - private bool _isScriptingObject; private int _scriptVTableSize = -1; private int _scriptVTableOffset; @@ -132,21 +129,6 @@ namespace Flax.Build.Bindings if (propertyInfo.Setter != null) ProcessAndValidate(propertyInfo.Setter); } - - foreach (var functionInfo in Functions) - ProcessAndValidate(functionInfo); - } - - private void ProcessAndValidate(FunctionInfo functionInfo) - { - // Ensure that methods have unique names for bindings - if (UniqueFunctionNames == null) - UniqueFunctionNames = new HashSet(); - int idx = 1; - functionInfo.UniqueName = functionInfo.Name; - while (UniqueFunctionNames.Contains(functionInfo.UniqueName)) - functionInfo.UniqueName = functionInfo.Name + idx++; - UniqueFunctionNames.Add(functionInfo.UniqueName); } public override void Write(BinaryWriter writer) @@ -158,7 +140,6 @@ namespace Flax.Build.Bindings writer.Write(IsAutoSerialization); writer.Write(NoSpawn); writer.Write(NoConstructor); - BindingsGenerator.Write(writer, Functions); BindingsGenerator.Write(writer, Properties); BindingsGenerator.Write(writer, Fields); BindingsGenerator.Write(writer, Events); @@ -175,7 +156,6 @@ namespace Flax.Build.Bindings IsAutoSerialization = reader.ReadBoolean(); NoSpawn = reader.ReadBoolean(); NoConstructor = reader.ReadBoolean(); - Functions = BindingsGenerator.Read(reader, Functions); Properties = BindingsGenerator.Read(reader, Properties); Fields = BindingsGenerator.Read(reader, Fields); Events = BindingsGenerator.Read(reader, Events); @@ -183,25 +163,51 @@ namespace Flax.Build.Bindings base.Read(reader); } - public int GetScriptVTableSize(Builder.BuildData buildData, out int offset) + public override int GetScriptVTableSize(out int offset) { if (_scriptVTableSize == -1) { if (BaseType is ClassInfo baseApiTypeInfo) { - _scriptVTableOffset = baseApiTypeInfo.GetScriptVTableSize(buildData, out _); + _scriptVTableOffset = baseApiTypeInfo.GetScriptVTableSize(out _); + } + if (Interfaces != null) + { + foreach (var interfaceInfo in Interfaces) + { + if (interfaceInfo.Access != AccessLevel.Public) + continue; + _scriptVTableOffset += interfaceInfo.GetScriptVTableSize(out _); + } } _scriptVTableSize = _scriptVTableOffset + Functions.Count(x => x.IsVirtual); + if (IsSealed) + { + // Skip vtables for sealed classes + _scriptVTableSize = _scriptVTableOffset = 0; + } } offset = _scriptVTableOffset; return _scriptVTableSize; } - public override void AddChild(ApiTypeInfo apiTypeInfo) + public override int GetScriptVTableOffset(VirtualClassInfo classInfo) { - apiTypeInfo.Namespace = null; - - base.AddChild(apiTypeInfo); + if (classInfo == BaseType) + return 0; + if (Interfaces != null) + { + var offset = BaseType is ClassInfo baseApiTypeInfo ? baseApiTypeInfo.GetScriptVTableSize(out _) : 0; + foreach (var interfaceInfo in Interfaces) + { + if (interfaceInfo.Access != AccessLevel.Public) + continue; + if (interfaceInfo == classInfo) + return offset; + offset += interfaceInfo.GetScriptVTableSize(out _); + } + } + throw new Exception($"Cannot get Script VTable offset for {classInfo} that is not part of {this}"); } public override string ToString() diff --git a/Source/Tools/Flax.Build/Bindings/ClassStructInfo.cs b/Source/Tools/Flax.Build/Bindings/ClassStructInfo.cs index ca1bb9257..b3871ad24 100644 --- a/Source/Tools/Flax.Build/Bindings/ClassStructInfo.cs +++ b/Source/Tools/Flax.Build/Bindings/ClassStructInfo.cs @@ -66,4 +66,59 @@ namespace Flax.Build.Bindings base.Read(reader); } } + + /// + /// The native class or interface information for bindings generator that contains virtual functions. + /// + public abstract class VirtualClassInfo : ClassStructInfo + { + public List Functions = new List(); + + internal HashSet UniqueFunctionNames; + + public override void Init(Builder.BuildData buildData) + { + base.Init(buildData); + + foreach (var functionInfo in Functions) + ProcessAndValidate(functionInfo); + } + + protected void ProcessAndValidate(FunctionInfo functionInfo) + { + // Ensure that methods have unique names for bindings + if (UniqueFunctionNames == null) + UniqueFunctionNames = new HashSet(); + int idx = 1; + functionInfo.UniqueName = functionInfo.Name; + while (UniqueFunctionNames.Contains(functionInfo.UniqueName)) + functionInfo.UniqueName = functionInfo.Name + idx++; + UniqueFunctionNames.Add(functionInfo.UniqueName); + } + + public abstract int GetScriptVTableSize(out int offset); + + public abstract int GetScriptVTableOffset(VirtualClassInfo classInfo); + + public override void Write(BinaryWriter writer) + { + BindingsGenerator.Write(writer, Functions); + + base.Write(writer); + } + + public override void Read(BinaryReader reader) + { + Functions = BindingsGenerator.Read(reader, Functions); + + base.Read(reader); + } + + public override void AddChild(ApiTypeInfo apiTypeInfo) + { + apiTypeInfo.Namespace = null; + + base.AddChild(apiTypeInfo); + } + } } diff --git a/Source/Tools/Flax.Build/Bindings/InterfaceInfo.cs b/Source/Tools/Flax.Build/Bindings/InterfaceInfo.cs index e0225f288..70032c596 100644 --- a/Source/Tools/Flax.Build/Bindings/InterfaceInfo.cs +++ b/Source/Tools/Flax.Build/Bindings/InterfaceInfo.cs @@ -1,19 +1,36 @@ // Copyright (c) 2012-2021 Wojciech Figat. All rights reserved. +using System; +using System.Linq; + namespace Flax.Build.Bindings { /// /// The native class/structure interface information for bindings generator. /// - public class InterfaceInfo : ClassStructInfo + public class InterfaceInfo : VirtualClassInfo { + public override int GetScriptVTableSize(out int offset) + { + offset = 0; + return Functions.Count(x => x.IsVirtual); + } + + public override int GetScriptVTableOffset(VirtualClassInfo classInfo) + { + return 0; + } + public override bool IsInterface => true; - public override void AddChild(ApiTypeInfo apiTypeInfo) + public override void Init(Builder.BuildData buildData) { - apiTypeInfo.Namespace = null; + base.Init(buildData); - base.AddChild(apiTypeInfo); + if (BaseType != null) + throw new Exception(string.Format("Interface {0} cannot inherit from {1}.", FullNameNative, BaseType)); + if (Interfaces != null && Interfaces.Count != 0) + throw new Exception(string.Format("Interface {0} cannot inherit from {1}.", FullNameNative, "interfaces")); } public override string ToString() diff --git a/Source/Tools/Flax.Build/Build/Plugins/VisualScriptingPlugin.cs b/Source/Tools/Flax.Build/Build/Plugins/VisualScriptingPlugin.cs index 95fd60c6e..c22812960 100644 --- a/Source/Tools/Flax.Build/Build/Plugins/VisualScriptingPlugin.cs +++ b/Source/Tools/Flax.Build/Build/Plugins/VisualScriptingPlugin.cs @@ -20,7 +20,7 @@ namespace Flax.Build.Plugins BindingsGenerator.CppScriptObjectVirtualWrapperMethodsPostfixes.Add("_VisualScriptWrapper"); } - private void OnGenerateCppScriptWrapperFunction(Builder.BuildData buildData, ClassInfo classInfo, FunctionInfo functionInfo, int scriptVTableSize, int scriptVTableIndex, StringBuilder contents) + private void OnGenerateCppScriptWrapperFunction(Builder.BuildData buildData, VirtualClassInfo classInfo, FunctionInfo functionInfo, int scriptVTableSize, int scriptVTableIndex, StringBuilder contents) { // Generate C++ wrapper function to invoke Visual Script instead of overridden native function (with support for base method callback) @@ -43,13 +43,29 @@ namespace Flax.Build.Plugins contents.Append(')'); contents.AppendLine(); contents.AppendLine(" {"); - contents.AppendLine($" auto object = ({classInfo.NativeName}*)this;"); + string scriptVTableOffset; + if (classInfo.IsInterface) + { + contents.AppendLine($" auto object = ScriptingObject::FromInterface(this, {classInfo.NativeName}::TypeInitializer);"); + contents.AppendLine(" if (object == nullptr)"); + contents.AppendLine(" {"); + contents.AppendLine($" LOG(Error, \"Failed to cast interface {{0}} to scripting object\", TEXT(\"{classInfo.Name}\"));"); + BindingsGenerator.GenerateCppReturn(buildData, contents, " ", functionInfo.ReturnType); + contents.AppendLine(" }"); + contents.AppendLine($" const int32 scriptVTableOffset = {scriptVTableIndex} + object->GetType().GetInterface({classInfo.NativeName}::TypeInitializer)->ScriptVTableOffset;"); + scriptVTableOffset = "scriptVTableOffset"; + } + else + { + contents.AppendLine($" auto object = ({classInfo.NativeName}*)this;"); + scriptVTableOffset = scriptVTableIndex.ToString(); + } contents.AppendLine(" static THREADLOCAL void* WrapperCallInstance = nullptr;"); contents.AppendLine(" if (WrapperCallInstance == object)"); contents.AppendLine(" {"); contents.AppendLine(" // Prevent stack overflow by calling base method"); contents.AppendLine(" const auto scriptVTableBase = object->GetType().Script.ScriptVTableBase;"); - contents.Append($" return (this->**({functionInfo.UniqueName}_Internal_Signature*)&scriptVTableBase[{scriptVTableIndex} + 2])("); + contents.Append($" return (this->**({functionInfo.UniqueName}_Internal_Signature*)&scriptVTableBase[{scriptVTableOffset} + 2])("); separator = false; for (var i = 0; i < functionInfo.Parameters.Count; i++) { @@ -62,7 +78,7 @@ namespace Flax.Build.Plugins contents.AppendLine(");"); contents.AppendLine(" }"); contents.AppendLine(" auto scriptVTable = (VisualScript::Method**)object->GetType().Script.ScriptVTable;"); - contents.AppendLine($" ASSERT(scriptVTable && scriptVTable[{scriptVTableIndex}]);"); + contents.AppendLine($" ASSERT(scriptVTable && scriptVTable[{scriptVTableOffset}]);"); if (functionInfo.Parameters.Count != 0) { @@ -80,7 +96,7 @@ namespace Flax.Build.Plugins contents.AppendLine(" auto prevWrapperCallInstance = WrapperCallInstance;"); contents.AppendLine(" WrapperCallInstance = object;"); - contents.AppendLine($" auto __result = VisualScripting::Invoke(scriptVTable[{scriptVTableIndex}], object, Span(parameters, {functionInfo.Parameters.Count}));"); + contents.AppendLine($" auto __result = VisualScripting::Invoke(scriptVTable[{scriptVTableOffset}], object, Span(parameters, {functionInfo.Parameters.Count}));"); contents.AppendLine(" WrapperCallInstance = prevWrapperCallInstance;"); if (!functionInfo.ReturnType.IsVoid)