From bf20f5d2bfd6977e3b70278a1d92f80c2f228043 Mon Sep 17 00:00:00 2001 From: Wojtek Figat Date: Mon, 30 Mar 2026 19:54:47 +0200 Subject: [PATCH] Fix scripting bindings in searching virtual methods to invoke when there is a name and parameter count collision #3649 --- Source/Engine/Scripting/BinaryModule.cpp | 64 +++------- Source/Engine/Scripting/BinaryModule.h | 4 +- Source/Engine/Scripting/ManagedCLR/MClass.h | 27 ++-- Source/Engine/Scripting/ManagedCLR/MUtils.cpp | 47 +++++-- Source/Engine/Scripting/ManagedCLR/MUtils.h | 1 + Source/Engine/Scripting/Runtime/DotNet.cpp | 35 +++++- Source/Engine/Scripting/Runtime/Mono.cpp | 5 + Source/Engine/Scripting/Runtime/None.cpp | 5 + .../Bindings/BindingsGenerator.Cpp.cs | 117 ++++++++++++++++-- 9 files changed, 216 insertions(+), 89 deletions(-) diff --git a/Source/Engine/Scripting/BinaryModule.cpp b/Source/Engine/Scripting/BinaryModule.cpp index bbcd7de57..1b254a0c6 100644 --- a/Source/Engine/Scripting/BinaryModule.cpp +++ b/Source/Engine/Scripting/BinaryModule.cpp @@ -32,6 +32,18 @@ ManagedBinaryModule* GetBinaryModuleCorlib() #endif } +MMethod* MClass::FindMethod(const char* name, int32 numParams, bool checkBaseClasses) const +{ + MMethod* method = GetMethod(name, numParams); + if (!method && checkBaseClasses) + { + MClass* base = GetBaseClass(); + if (base) + method = base->FindMethod(name, numParams, true); + } + return method; +} + ScriptingTypeHandle::ScriptingTypeHandle(const ScriptingTypeInitializer& initializer) : Module(initializer.Module) , TypeIndex(initializer.TypeIndex) @@ -835,61 +847,17 @@ namespace } return nullptr; } - - bool VariantTypeEquals(const VariantType& type, MType* mType, bool isOut = false) - { - MClass* mClass = MCore::Type::GetClass(mType); - MClass* variantClass = MUtils::GetClass(type); - if (variantClass != mClass) - { - // Hack for Vector2/3/4 which alias with Float2/3/4 or Double2/3/4 (depending on USE_LARGE_WORLDS) - const auto& stdTypes = *StdTypesContainer::Instance(); - if (mClass == stdTypes.Vector2Class && (type.Type == VariantType::Float2 || type.Type == VariantType::Double2)) - return true; - if (mClass == stdTypes.Vector3Class && (type.Type == VariantType::Float3 || type.Type == VariantType::Double3)) - return true; - if (mClass == stdTypes.Vector4Class && (type.Type == VariantType::Float4 || type.Type == VariantType::Double4)) - return true; - - return false; - } - return true; - } } #endif -MMethod* ManagedBinaryModule::FindMethod(MClass* mclass, const ScriptingTypeMethodSignature& signature) +MMethod* ManagedBinaryModule::FindMethod(const MClass* mclass, const ScriptingTypeMethodSignature& signature) { #if USE_CSHARP - if (!mclass) - return nullptr; - const auto& methods = mclass->GetMethods(); - for (MMethod* method : methods) - { - if (method->IsStatic() != signature.IsStatic) - continue; - if (method->GetName() != signature.Name) - continue; - if (method->GetParametersCount() != signature.Params.Count()) - continue; - bool isValid = true; - for (int32 paramIdx = 0; paramIdx < signature.Params.Count(); paramIdx++) - { - auto& param = signature.Params[paramIdx]; - MType* type = method->GetParameterType(paramIdx); - if (param.IsOut != method->GetParameterIsOut(paramIdx) || - !VariantTypeEquals(param.Type, type, param.IsOut)) - { - isValid = false; - break; - } - } - if (isValid && VariantTypeEquals(signature.ReturnType, method->GetReturnType())) - return method; - } -#endif + return mclass ? mclass->GetMethod(signature) : nullptr; +#else return nullptr; +#endif } #if USE_CSHARP diff --git a/Source/Engine/Scripting/BinaryModule.h b/Source/Engine/Scripting/BinaryModule.h index 1da35401b..db346b13c 100644 --- a/Source/Engine/Scripting/BinaryModule.h +++ b/Source/Engine/Scripting/BinaryModule.h @@ -23,7 +23,7 @@ struct ScriptingTypeMethodSignature StringAnsiView Name; VariantType ReturnType; - bool IsStatic; + bool IsStatic = false; Array> Params; }; @@ -322,7 +322,7 @@ public: #endif static ScriptingObject* ManagedObjectSpawn(const ScriptingObjectSpawnParams& params); - static MMethod* FindMethod(MClass* mclass, const ScriptingTypeMethodSignature& signature); + static MMethod* FindMethod(const MClass* mclass, const ScriptingTypeMethodSignature& signature); #if USE_CSHARP static ManagedBinaryModule* FindModule(const MClass* klass); static ScriptingTypeHandle FindType(const MClass* klass); diff --git a/Source/Engine/Scripting/ManagedCLR/MClass.h b/Source/Engine/Scripting/ManagedCLR/MClass.h index 61273dae7..ea26451c1 100644 --- a/Source/Engine/Scripting/ManagedCLR/MClass.h +++ b/Source/Engine/Scripting/ManagedCLR/MClass.h @@ -188,11 +188,11 @@ public: MClass* GetBaseClass() const; /// - /// Checks if this class is a sub class of the specified class (including any derived types). + /// Checks if this class is a subclass of the specified class (including any derived types). /// /// The class. /// True if check interfaces, otherwise just base class. - /// True if this class is a sub class of the specified class. + /// True if this class is a subclass of the specified class. bool IsSubClassOf(const MClass* klass, bool checkInterfaces = false) const; /// @@ -206,7 +206,7 @@ public: /// Checks is the provided object instance of this class' type. /// /// The object to check. - /// True if object is an instance the this class. + /// True if object is an instance this class. bool IsInstanceOfType(MObject* object) const; /// @@ -227,17 +227,7 @@ public: /// The method parameters count. /// True if check base classes when searching for the given method. /// The method or null if failed to find it. - MMethod* FindMethod(const char* name, int32 numParams, bool checkBaseClasses = true) const - { - MMethod* method = GetMethod(name, numParams); - if (!method && checkBaseClasses) - { - MClass* base = GetBaseClass(); - if (base) - method = base->FindMethod(name, numParams, true); - } - return method; - } + MMethod* FindMethod(const char* name, int32 numParams, bool checkBaseClasses = true) const; /// /// Returns an object referencing a method with the specified name and number of parameters. @@ -248,6 +238,13 @@ public: /// The method or null if failed to get it. MMethod* GetMethod(const char* name, int32 numParams = 0) const; + /// + /// Returns an object referencing a method with the specified signature. + /// + /// The method signature. + /// The method or null if failed to get it. + MMethod* GetMethod(const struct ScriptingTypeMethodSignature& signature) const; + /// /// Returns all methods belonging to this class. /// @@ -271,7 +268,7 @@ public: const Array& GetFields() const; /// - /// Returns an object referencing a event with the specified name. + /// Returns an object referencing an event with the specified name. /// /// The event name. /// The event object. diff --git a/Source/Engine/Scripting/ManagedCLR/MUtils.cpp b/Source/Engine/Scripting/ManagedCLR/MUtils.cpp index 69fb5aa08..0d7617e0d 100644 --- a/Source/Engine/Scripting/ManagedCLR/MUtils.cpp +++ b/Source/Engine/Scripting/ManagedCLR/MUtils.cpp @@ -845,7 +845,6 @@ MClass* MUtils::GetClass(const VariantType& value) auto mclass = Scripting::FindClass(StringAnsiView(value.TypeName)); if (mclass) return mclass; - const auto& stdTypes = *StdTypesContainer::Instance(); switch (value.Type) { case VariantType::Void: @@ -891,25 +890,25 @@ MClass* MUtils::GetClass(const VariantType& value) case VariantType::Double4: return Double4::TypeInitializer.GetClass(); case VariantType::Color: - return stdTypes.ColorClass; + return Color::TypeInitializer.GetClass(); case VariantType::Guid: - return stdTypes.GuidClass; + return GetBinaryModuleCorlib()->Assembly->GetClass("System.Guid"); case VariantType::Typename: - return stdTypes.TypeClass; + return GetBinaryModuleCorlib()->Assembly->GetClass("System.Type"); case VariantType::BoundingBox: - return stdTypes.BoundingBoxClass; + return BoundingBox::TypeInitializer.GetClass(); case VariantType::BoundingSphere: - return stdTypes.BoundingSphereClass; + return BoundingSphere::TypeInitializer.GetClass(); case VariantType::Quaternion: - return stdTypes.QuaternionClass; + return Quaternion::TypeInitializer.GetClass(); case VariantType::Transform: - return stdTypes.TransformClass; + return Transform::TypeInitializer.GetClass(); case VariantType::Rectangle: - return stdTypes.RectangleClass; + return Rectangle::TypeInitializer.GetClass(); case VariantType::Ray: - return stdTypes.RayClass; + return Ray::TypeInitializer.GetClass(); case VariantType::Matrix: - return stdTypes.MatrixClass; + return Matrix::TypeInitializer.GetClass(); case VariantType::Array: if (value.TypeName) { @@ -1202,8 +1201,7 @@ void* MUtils::VariantToManagedArgPtr(Variant& value, MType* type, bool& failed) if (value.Type.Type != VariantType::Array) return nullptr; MObject* object = BoxVariant(value); - auto typeStr = MCore::Type::ToString(type); - if (object && !MCore::Object::GetClass(object)->IsSubClassOf(MCore::Type::GetClass(type))) + if (object && MCore::Type::GetClass(type) != MCore::Array::GetArrayClass((MArray*)object)) object = nullptr; return object; } @@ -1238,6 +1236,29 @@ void* MUtils::VariantToManagedArgPtr(Variant& value, MType* type, bool& failed) return nullptr; } +bool MUtils::VariantTypeEquals(const VariantType& type, MType* mType, bool isOut) +{ + MClass* mClass = MCore::Type::GetClass(mType); + MClass* variantClass = MUtils::GetClass(type); + if (variantClass != mClass) + { + // Hack for Vector2/3/4 which alias with Float2/3/4 or Double2/3/4 (depending on USE_LARGE_WORLDS) + if (mClass->GetFullName() == StringAnsiView("FlaxEngine.Vector2", 18) && (type.Type == VariantType::Float2 || type.Type == VariantType::Double2)) + return true; + if (mClass->GetFullName() == StringAnsiView("FlaxEngine.Vector3", 18) && (type.Type == VariantType::Float3 || type.Type == VariantType::Double3)) + return true; + if (mClass->GetFullName() == StringAnsiView("FlaxEngine.Vector4", 18) && (type.Type == VariantType::Float4 || type.Type == VariantType::Double4)) + return true; + + // Arrays + if (type == VariantType::Array && type.GetElementType() == VariantType::Object) + return MCore::Type::GetType(mType) == MTypes::Array; + + return false; + } + return true; +} + MObject* MUtils::ToManaged(const Version& value) { #if USE_NETCORE diff --git a/Source/Engine/Scripting/ManagedCLR/MUtils.h b/Source/Engine/Scripting/ManagedCLR/MUtils.h index f642681d2..5598dbee0 100644 --- a/Source/Engine/Scripting/ManagedCLR/MUtils.h +++ b/Source/Engine/Scripting/ManagedCLR/MUtils.h @@ -621,6 +621,7 @@ namespace MUtils #endif extern void* VariantToManagedArgPtr(Variant& value, MType* type, bool& failed); + extern bool VariantTypeEquals(const VariantType& type, MType* mType, bool isOut = false); extern MObject* ToManaged(const Version& value); extern Version ToNative(MObject* value); diff --git a/Source/Engine/Scripting/Runtime/DotNet.cpp b/Source/Engine/Scripting/Runtime/DotNet.cpp index 4be0ce1a1..21d63e9d6 100644 --- a/Source/Engine/Scripting/Runtime/DotNet.cpp +++ b/Source/Engine/Scripting/Runtime/DotNet.cpp @@ -1062,10 +1062,39 @@ MClass* MClass::GetElementClass() const MMethod* MClass::GetMethod(const char* name, int32 numParams) const { GetMethods(); - for (int32 i = 0; i < _methods.Count(); i++) + for (MMethod* method : _methods) { - if (_methods[i]->GetParametersCount() == numParams && _methods[i]->GetName() == name) - return _methods[i]; + if (method->GetParametersCount() == numParams && method->GetName() == name) + return method; + } + return nullptr; +} + +MMethod* MClass::GetMethod(const ScriptingTypeMethodSignature& signature) const +{ + GetMethods(); + for (MMethod* method : _methods) + { + if (method->IsStatic() != signature.IsStatic) + continue; + if (method->GetName() != signature.Name) + continue; + if (method->GetParametersCount() != signature.Params.Count()) + continue; + bool isValid = true; + for (int32 paramIdx = 0; paramIdx < signature.Params.Count(); paramIdx++) + { + auto& param = signature.Params[paramIdx]; + MType* type = method->GetParameterType(paramIdx); + if (param.IsOut != method->GetParameterIsOut(paramIdx) || + !MUtils::VariantTypeEquals(param.Type, type, param.IsOut)) + { + isValid = false; + break; + } + } + if (isValid && (signature.ReturnType.Type == VariantType::Null || MUtils::VariantTypeEquals(signature.ReturnType, method->GetReturnType()))) + return method; } return nullptr; } diff --git a/Source/Engine/Scripting/Runtime/Mono.cpp b/Source/Engine/Scripting/Runtime/Mono.cpp index 06392f932..1449d405f 100644 --- a/Source/Engine/Scripting/Runtime/Mono.cpp +++ b/Source/Engine/Scripting/Runtime/Mono.cpp @@ -1357,6 +1357,11 @@ MMethod* MClass::GetMethod(const char* name, int32 numParams) const return method; } +MMethod* MClass::GetMethod(const ScriptingTypeMethodSignature& signature) const +{ + return GetMethod(signature.Name.Get(), signature.Params.Count()); +} + const Array& MClass::GetMethods() const { if (_hasCachedMethods) diff --git a/Source/Engine/Scripting/Runtime/None.cpp b/Source/Engine/Scripting/Runtime/None.cpp index 1ddaeae8e..178e98444 100644 --- a/Source/Engine/Scripting/Runtime/None.cpp +++ b/Source/Engine/Scripting/Runtime/None.cpp @@ -363,6 +363,11 @@ MMethod* MClass::GetMethod(const char* name, int32 numParams) const return nullptr; } +MMethod* MClass::GetMethod(const ScriptingTypeMethodSignature& signature) const +{ + return nullptr; +} + const Array& MClass::GetMethods() const { _hasCachedMethods = true; diff --git a/Source/Tools/Flax.Build/Bindings/BindingsGenerator.Cpp.cs b/Source/Tools/Flax.Build/Bindings/BindingsGenerator.Cpp.cs index 4dbbe576c..50d3aa85c 100644 --- a/Source/Tools/Flax.Build/Bindings/BindingsGenerator.Cpp.cs +++ b/Source/Tools/Flax.Build/Bindings/BindingsGenerator.Cpp.cs @@ -215,6 +215,81 @@ namespace Flax.Build.Bindings return $"Variant({value})"; } + public static string GenerateCppWrapperNativeToVariantType(BuildData buildData, TypeInfo typeInfo, ApiTypeInfo caller) + { + // In-built types + switch (typeInfo.Type) + { + case "void": return "VariantType(VariantType::Void)"; + case "bool": return "VariantType(VariantType::Bool)"; + case "int": + case "int32": return "VariantType(VariantType::Int)"; + case "uint": + case "uint32": return "VariantType(VariantType::Uint)"; + case "int64": return "VariantType(VariantType::Int64)"; + case "uint64": return "VariantType(VariantType::Uint64)"; + case "Real": + case "float": return "VariantType(VariantType::Float)"; + case "double": return "VariantType(VariantType::Double)"; + case "StringAnsiView": + case "StringAnsi": + case "StringView": + case "String": return "VariantType(VariantType::String)"; + case "Guid": return "VariantType(VariantType::Guid)"; + case "Asset": return "VariantType(VariantType::Asset)"; + case "Float2": return "VariantType(VariantType::Float2)"; + case "Float3": return "VariantType(VariantType::Float3)"; + case "Float4": return "VariantType(VariantType::Float4)"; + case "Double2": return "VariantType(VariantType::Double2)"; + case "Double3": return "VariantType(VariantType::Double3)"; + case "Double4": return "VariantType(VariantType::Double4)"; + case "Vector2": return "VariantType(VariantType::Vector2)"; + case "Vector3": return "VariantType(VariantType::Vector3)"; + case "Vector4": return "VariantType(VariantType::Vector4)"; + case "Int2": return "VariantType(VariantType::Int2)"; + case "Int3": return "VariantType(VariantType::Int3)"; + case "Int4": return "VariantType(VariantType::Int4)"; + case "Color": return "VariantType(VariantType::Color)"; + case "BoundingBox": return "VariantType(VariantType::BoundingBox)"; + case "BoundingSphere": return "VariantType(VariantType::BoundingSphere)"; + case "Quaternion": return "VariantType(VariantType::Quaternion)"; + case "Transform": return "VariantType(VariantType::Transform)"; + case "Rectangle": return "VariantType(VariantType::Rectangle)"; + case "Ray": return "VariantType(VariantType::Ray)"; + case "Matrix": return "VariantType(VariantType::Matrix)"; + case "Type": return "VariantType(VariantType::Typename)"; + } + + // Array + if (typeInfo.IsArray) + return "VariantType(VariantType::Array)"; + if ((typeInfo.Type == "Array" || typeInfo.Type == "Span") && typeInfo.GenericArgs != null) + { + var elementType = FindApiTypeInfo(buildData, typeInfo.GenericArgs[0], caller); + var elementName = $"{(elementType != null ? elementType.FullNameManaged : typeInfo.GenericArgs[0].Type)}[]"; + return $"VariantType(VariantType::Array, StringAnsiView(\"{elementName}\", {elementName.Length}))"; + } + if (typeInfo.Type == "Dictionary" && typeInfo.GenericArgs != null) + return "VariantType(VariantType::Dictionary)"; + + // Scripting type + var apiType = FindApiTypeInfo(buildData, typeInfo, caller); + if (apiType != null) + { + // TODO: optimize VariantType for explicitly defined types to use static name and less mem/allocs + var fullname = apiType.FullNameManaged; + if (apiType.IsEnum) + return $"VariantType(VariantType::Enum, StringAnsiView(\"{fullname}\", {fullname.Length}))"; + if (apiType.IsStruct) + return $"VariantType(VariantType::Structure, StringAnsiView(\"{fullname}\", {fullname.Length}))"; + if (apiType.IsClass) + return $"VariantType(VariantType::Object, StringAnsiView(\"{fullname}\", {fullname.Length}))"; + } + + // Unknown + return "VariantType()"; + } + public static string GenerateCppWrapperVariantToNative(BuildData buildData, TypeInfo typeInfo, ApiTypeInfo caller, string value) { if (typeInfo.Type == "Variant") @@ -393,7 +468,7 @@ namespace Flax.Build.Bindings return "Scripting::FindClass(\"" + managedType + "\")"; } - private static string GenerateCppGetNativeType(BuildData buildData, TypeInfo typeInfo, ApiTypeInfo caller, FunctionInfo functionInfo) + private static string GenerateCppGetNativeType(BuildData buildData, TypeInfo typeInfo, ApiTypeInfo caller, FunctionInfo functionInfo = null) { CppIncludeFiles.Add("Engine/Scripting/ManagedCLR/MClass.h"); @@ -1751,7 +1826,23 @@ namespace Flax.Build.Bindings { if (!functionInfo.IsVirtual) continue; - contents.AppendLine($" scriptVTable[{scriptVTableIndex++}] = mclass->GetMethod(\"{functionInfo.Name}\", {functionInfo.Parameters.Count});"); + + // Don't use exact signature for parameter-less methods or the ones without duplicates + if (functionInfo.Parameters.Count == 0 || !classInfo.Functions.Any(x => x != functionInfo && x.Parameters.Count == functionInfo.Parameters.Count && x.Name == functionInfo.Name)) + contents.AppendLine($" scriptVTable[{scriptVTableIndex++}] = mclass->GetMethod(\"{functionInfo.Name}\", {functionInfo.Parameters.Count});"); + else + { + contents.AppendLine(" {"); + contents.AppendLine(" ScriptingTypeMethodSignature signature;"); + contents.AppendLine($" signature.Name = StringAnsiView(\"{functionInfo.Name}\", {functionInfo.Name.Length});"); + contents.AppendLine($" signature.Params.Resize({functionInfo.Parameters.Count});"); + for (var i = 0; i < functionInfo.Parameters.Count; i++) + contents.AppendLine($" signature.Params[{i}] = {{ {GenerateCppWrapperNativeToVariantType(buildData, functionInfo.Parameters[i].Type, classInfo)}, {(functionInfo.Parameters[i].IsOut ? "true" : "false")} }};"); + contents.AppendLine($" scriptVTable[{scriptVTableIndex++}] = mclass->GetMethod(signature);"); + if (buildData.Configuration != TargetConfiguration.Release) + contents.AppendLine($" ASSERT(scriptVTable[{scriptVTableIndex - 1}]);"); + contents.AppendLine(" }"); + } } contents.AppendLine(" }"); contents.AppendLine(""); @@ -2672,10 +2763,22 @@ namespace Flax.Build.Bindings { contents.AppendLine(" Variant* parameters = nullptr;"); } + if (functionInfo.Parameters.Count != 0) + { + // Build method signature to find method using exact parameter types to match on name collisions + contents.AppendLine(" ScriptingTypeMethodSignature signature;"); + contents.AppendLine($" signature.Name = StringAnsiView(\"{functionInfo.Name}\", {functionInfo.Name.Length});"); + contents.AppendLine($" signature.Params.Resize({functionInfo.Parameters.Count});"); + for (var i = 0; i < functionInfo.Parameters.Count; i++) + contents.AppendLine($" signature.Params[{i}] = {{ parameters[{i}].Type, {(functionInfo.Parameters[i].IsOut ? "true" : "false")} }};"); + } contents.AppendLine(" auto typeHandle = Object->GetTypeHandle();"); contents.AppendLine(" while (typeHandle)"); contents.AppendLine(" {"); - contents.AppendLine($" auto method = typeHandle.Module->FindMethod(typeHandle, StringAnsiView(\"{functionInfo.Name}\", {functionInfo.Name.Length}), {functionInfo.Parameters.Count});"); + if (functionInfo.Parameters.Count == 0) + contents.AppendLine($" auto method = typeHandle.Module->FindMethod(typeHandle, StringAnsiView(\"{functionInfo.Name}\", {functionInfo.Name.Length}), {functionInfo.Parameters.Count});"); + else + contents.AppendLine(" auto method = typeHandle.Module->FindMethod(typeHandle, signature);"); contents.AppendLine(" if (method)"); contents.AppendLine(" {"); contents.AppendLine(" Variant __result;"); @@ -2948,7 +3051,9 @@ namespace Flax.Build.Bindings header.Append($"{wrapperName}Array(const {valueType}* v, const int32 length)").AppendLine(); header.Append('{').AppendLine(); header.Append(" Variant result;").AppendLine(); - header.Append(" result.SetType(VariantType(VariantType::Array));").AppendLine(); + var apiType = FindApiTypeInfo(buildData, valueType, moduleInfo); + var elementName = $"{(apiType != null ? apiType.FullNameManaged : valueType.Type)}[]"; + header.Append($" result.SetType(VariantType(VariantType::Array, StringAnsiView(\"{elementName}\", {elementName.Length})));").AppendLine(); header.Append(" auto* array = reinterpret_cast*>(result.AsData);").AppendLine(); header.Append(" array->Resize(length);").AppendLine(); header.Append(" for (int32 i = 0; i < length; i++)").AppendLine(); @@ -3316,13 +3421,9 @@ namespace Flax.Build.Bindings contents.AppendLine($"extern \"C\" BinaryModule* GetBinaryModule{binaryModuleName}()"); contents.AppendLine("{"); if (useCSharp) - { contents.AppendLine($" static NativeBinaryModule module(\"{binaryModuleName}\");"); - } else - { contents.AppendLine($" static NativeOnlyBinaryModule module(\"{binaryModuleName}\");"); - } contents.AppendLine(" return &module;"); contents.AppendLine("}"); if (project.VersionControlBranch.Length != 0)