Fix scripting bindings in searching virtual methods to invoke when there is a name and parameter count collision

#3649
This commit is contained in:
Wojtek Figat
2026-03-30 19:54:47 +02:00
parent 767854a2af
commit bf20f5d2bf
9 changed files with 216 additions and 89 deletions

View File

@@ -32,6 +32,18 @@ ManagedBinaryModule* GetBinaryModuleCorlib()
#endif
}
MMethod* MClass::FindMethod(const char* name, int32 numParams, bool checkBaseClasses) const
{
MMethod* method = GetMethod(name, numParams);
if (!method && checkBaseClasses)
{
MClass* base = GetBaseClass();
if (base)
method = base->FindMethod(name, numParams, true);
}
return method;
}
ScriptingTypeHandle::ScriptingTypeHandle(const ScriptingTypeInitializer& initializer)
: Module(initializer.Module)
, TypeIndex(initializer.TypeIndex)
@@ -835,61 +847,17 @@ namespace
}
return nullptr;
}
bool VariantTypeEquals(const VariantType& type, MType* mType, bool isOut = false)
{
MClass* mClass = MCore::Type::GetClass(mType);
MClass* variantClass = MUtils::GetClass(type);
if (variantClass != mClass)
{
// Hack for Vector2/3/4 which alias with Float2/3/4 or Double2/3/4 (depending on USE_LARGE_WORLDS)
const auto& stdTypes = *StdTypesContainer::Instance();
if (mClass == stdTypes.Vector2Class && (type.Type == VariantType::Float2 || type.Type == VariantType::Double2))
return true;
if (mClass == stdTypes.Vector3Class && (type.Type == VariantType::Float3 || type.Type == VariantType::Double3))
return true;
if (mClass == stdTypes.Vector4Class && (type.Type == VariantType::Float4 || type.Type == VariantType::Double4))
return true;
return false;
}
return true;
}
}
#endif
MMethod* ManagedBinaryModule::FindMethod(MClass* mclass, const ScriptingTypeMethodSignature& signature)
MMethod* ManagedBinaryModule::FindMethod(const MClass* mclass, const ScriptingTypeMethodSignature& signature)
{
#if USE_CSHARP
if (!mclass)
return nullptr;
const auto& methods = mclass->GetMethods();
for (MMethod* method : methods)
{
if (method->IsStatic() != signature.IsStatic)
continue;
if (method->GetName() != signature.Name)
continue;
if (method->GetParametersCount() != signature.Params.Count())
continue;
bool isValid = true;
for (int32 paramIdx = 0; paramIdx < signature.Params.Count(); paramIdx++)
{
auto& param = signature.Params[paramIdx];
MType* type = method->GetParameterType(paramIdx);
if (param.IsOut != method->GetParameterIsOut(paramIdx) ||
!VariantTypeEquals(param.Type, type, param.IsOut))
{
isValid = false;
break;
}
}
if (isValid && VariantTypeEquals(signature.ReturnType, method->GetReturnType()))
return method;
}
#endif
return mclass ? mclass->GetMethod(signature) : nullptr;
#else
return nullptr;
#endif
}
#if USE_CSHARP

View File

@@ -23,7 +23,7 @@ struct ScriptingTypeMethodSignature
StringAnsiView Name;
VariantType ReturnType;
bool IsStatic;
bool IsStatic = false;
Array<Param, InlinedAllocation<16>> Params;
};
@@ -322,7 +322,7 @@ public:
#endif
static ScriptingObject* ManagedObjectSpawn(const ScriptingObjectSpawnParams& params);
static MMethod* FindMethod(MClass* mclass, const ScriptingTypeMethodSignature& signature);
static MMethod* FindMethod(const MClass* mclass, const ScriptingTypeMethodSignature& signature);
#if USE_CSHARP
static ManagedBinaryModule* FindModule(const MClass* klass);
static ScriptingTypeHandle FindType(const MClass* klass);

View File

@@ -188,11 +188,11 @@ public:
MClass* GetBaseClass() const;
/// <summary>
/// Checks if this class is a sub class of the specified class (including any derived types).
/// Checks if this class is a subclass of the specified class (including any derived types).
/// </summary>
/// <param name="klass">The class.</param>
/// <param name="checkInterfaces">True if check interfaces, otherwise just base class.</param>
/// <returns>True if this class is a sub class of the specified class.</returns>
/// <returns>True if this class is a subclass of the specified class.</returns>
bool IsSubClassOf(const MClass* klass, bool checkInterfaces = false) const;
/// <summary>
@@ -206,7 +206,7 @@ public:
/// Checks is the provided object instance of this class' type.
/// </summary>
/// <param name="object">The object to check.</param>
/// <returns>True if object is an instance the this class.</returns>
/// <returns>True if object is an instance this class.</returns>
bool IsInstanceOfType(MObject* object) const;
/// <summary>
@@ -227,17 +227,7 @@ public:
/// <param name="numParams">The method parameters count.</param>
/// <param name="checkBaseClasses">True if check base classes when searching for the given method.</param>
/// <returns>The method or null if failed to find it.</returns>
MMethod* FindMethod(const char* name, int32 numParams, bool checkBaseClasses = true) const
{
MMethod* method = GetMethod(name, numParams);
if (!method && checkBaseClasses)
{
MClass* base = GetBaseClass();
if (base)
method = base->FindMethod(name, numParams, true);
}
return method;
}
MMethod* FindMethod(const char* name, int32 numParams, bool checkBaseClasses = true) const;
/// <summary>
/// Returns an object referencing a method with the specified name and number of parameters.
@@ -248,6 +238,13 @@ public:
/// <returns>The method or null if failed to get it.</returns>
MMethod* GetMethod(const char* name, int32 numParams = 0) const;
/// <summary>
/// Returns an object referencing a method with the specified signature.
/// </summary>
/// <param name="signature">The method signature.</param>
/// <returns>The method or null if failed to get it.</returns>
MMethod* GetMethod(const struct ScriptingTypeMethodSignature& signature) const;
/// <summary>
/// Returns all methods belonging to this class.
/// </summary>
@@ -271,7 +268,7 @@ public:
const Array<MField*, ArenaAllocation>& GetFields() const;
/// <summary>
/// Returns an object referencing a event with the specified name.
/// Returns an object referencing an event with the specified name.
/// </summary>
/// <param name="name">The event name.</param>
/// <returns>The event object.</returns>

View File

@@ -845,7 +845,6 @@ MClass* MUtils::GetClass(const VariantType& value)
auto mclass = Scripting::FindClass(StringAnsiView(value.TypeName));
if (mclass)
return mclass;
const auto& stdTypes = *StdTypesContainer::Instance();
switch (value.Type)
{
case VariantType::Void:
@@ -891,25 +890,25 @@ MClass* MUtils::GetClass(const VariantType& value)
case VariantType::Double4:
return Double4::TypeInitializer.GetClass();
case VariantType::Color:
return stdTypes.ColorClass;
return Color::TypeInitializer.GetClass();
case VariantType::Guid:
return stdTypes.GuidClass;
return GetBinaryModuleCorlib()->Assembly->GetClass("System.Guid");
case VariantType::Typename:
return stdTypes.TypeClass;
return GetBinaryModuleCorlib()->Assembly->GetClass("System.Type");
case VariantType::BoundingBox:
return stdTypes.BoundingBoxClass;
return BoundingBox::TypeInitializer.GetClass();
case VariantType::BoundingSphere:
return stdTypes.BoundingSphereClass;
return BoundingSphere::TypeInitializer.GetClass();
case VariantType::Quaternion:
return stdTypes.QuaternionClass;
return Quaternion::TypeInitializer.GetClass();
case VariantType::Transform:
return stdTypes.TransformClass;
return Transform::TypeInitializer.GetClass();
case VariantType::Rectangle:
return stdTypes.RectangleClass;
return Rectangle::TypeInitializer.GetClass();
case VariantType::Ray:
return stdTypes.RayClass;
return Ray::TypeInitializer.GetClass();
case VariantType::Matrix:
return stdTypes.MatrixClass;
return Matrix::TypeInitializer.GetClass();
case VariantType::Array:
if (value.TypeName)
{
@@ -1202,8 +1201,7 @@ void* MUtils::VariantToManagedArgPtr(Variant& value, MType* type, bool& failed)
if (value.Type.Type != VariantType::Array)
return nullptr;
MObject* object = BoxVariant(value);
auto typeStr = MCore::Type::ToString(type);
if (object && !MCore::Object::GetClass(object)->IsSubClassOf(MCore::Type::GetClass(type)))
if (object && MCore::Type::GetClass(type) != MCore::Array::GetArrayClass((MArray*)object))
object = nullptr;
return object;
}
@@ -1238,6 +1236,29 @@ void* MUtils::VariantToManagedArgPtr(Variant& value, MType* type, bool& failed)
return nullptr;
}
bool MUtils::VariantTypeEquals(const VariantType& type, MType* mType, bool isOut)
{
MClass* mClass = MCore::Type::GetClass(mType);
MClass* variantClass = MUtils::GetClass(type);
if (variantClass != mClass)
{
// Hack for Vector2/3/4 which alias with Float2/3/4 or Double2/3/4 (depending on USE_LARGE_WORLDS)
if (mClass->GetFullName() == StringAnsiView("FlaxEngine.Vector2", 18) && (type.Type == VariantType::Float2 || type.Type == VariantType::Double2))
return true;
if (mClass->GetFullName() == StringAnsiView("FlaxEngine.Vector3", 18) && (type.Type == VariantType::Float3 || type.Type == VariantType::Double3))
return true;
if (mClass->GetFullName() == StringAnsiView("FlaxEngine.Vector4", 18) && (type.Type == VariantType::Float4 || type.Type == VariantType::Double4))
return true;
// Arrays
if (type == VariantType::Array && type.GetElementType() == VariantType::Object)
return MCore::Type::GetType(mType) == MTypes::Array;
return false;
}
return true;
}
MObject* MUtils::ToManaged(const Version& value)
{
#if USE_NETCORE

View File

@@ -621,6 +621,7 @@ namespace MUtils
#endif
extern void* VariantToManagedArgPtr(Variant& value, MType* type, bool& failed);
extern bool VariantTypeEquals(const VariantType& type, MType* mType, bool isOut = false);
extern MObject* ToManaged(const Version& value);
extern Version ToNative(MObject* value);

View File

@@ -1062,10 +1062,39 @@ MClass* MClass::GetElementClass() const
MMethod* MClass::GetMethod(const char* name, int32 numParams) const
{
GetMethods();
for (int32 i = 0; i < _methods.Count(); i++)
for (MMethod* method : _methods)
{
if (_methods[i]->GetParametersCount() == numParams && _methods[i]->GetName() == name)
return _methods[i];
if (method->GetParametersCount() == numParams && method->GetName() == name)
return method;
}
return nullptr;
}
MMethod* MClass::GetMethod(const ScriptingTypeMethodSignature& signature) const
{
GetMethods();
for (MMethod* method : _methods)
{
if (method->IsStatic() != signature.IsStatic)
continue;
if (method->GetName() != signature.Name)
continue;
if (method->GetParametersCount() != signature.Params.Count())
continue;
bool isValid = true;
for (int32 paramIdx = 0; paramIdx < signature.Params.Count(); paramIdx++)
{
auto& param = signature.Params[paramIdx];
MType* type = method->GetParameterType(paramIdx);
if (param.IsOut != method->GetParameterIsOut(paramIdx) ||
!MUtils::VariantTypeEquals(param.Type, type, param.IsOut))
{
isValid = false;
break;
}
}
if (isValid && (signature.ReturnType.Type == VariantType::Null || MUtils::VariantTypeEquals(signature.ReturnType, method->GetReturnType())))
return method;
}
return nullptr;
}

View File

@@ -1357,6 +1357,11 @@ MMethod* MClass::GetMethod(const char* name, int32 numParams) const
return method;
}
MMethod* MClass::GetMethod(const ScriptingTypeMethodSignature& signature) const
{
return GetMethod(signature.Name.Get(), signature.Params.Count());
}
const Array<MMethod*>& MClass::GetMethods() const
{
if (_hasCachedMethods)

View File

@@ -363,6 +363,11 @@ MMethod* MClass::GetMethod(const char* name, int32 numParams) const
return nullptr;
}
MMethod* MClass::GetMethod(const ScriptingTypeMethodSignature& signature) const
{
return nullptr;
}
const Array<MMethod*, ArenaAllocation>& MClass::GetMethods() const
{
_hasCachedMethods = true;

View File

@@ -215,6 +215,81 @@ namespace Flax.Build.Bindings
return $"Variant({value})";
}
public static string GenerateCppWrapperNativeToVariantType(BuildData buildData, TypeInfo typeInfo, ApiTypeInfo caller)
{
// In-built types
switch (typeInfo.Type)
{
case "void": return "VariantType(VariantType::Void)";
case "bool": return "VariantType(VariantType::Bool)";
case "int":
case "int32": return "VariantType(VariantType::Int)";
case "uint":
case "uint32": return "VariantType(VariantType::Uint)";
case "int64": return "VariantType(VariantType::Int64)";
case "uint64": return "VariantType(VariantType::Uint64)";
case "Real":
case "float": return "VariantType(VariantType::Float)";
case "double": return "VariantType(VariantType::Double)";
case "StringAnsiView":
case "StringAnsi":
case "StringView":
case "String": return "VariantType(VariantType::String)";
case "Guid": return "VariantType(VariantType::Guid)";
case "Asset": return "VariantType(VariantType::Asset)";
case "Float2": return "VariantType(VariantType::Float2)";
case "Float3": return "VariantType(VariantType::Float3)";
case "Float4": return "VariantType(VariantType::Float4)";
case "Double2": return "VariantType(VariantType::Double2)";
case "Double3": return "VariantType(VariantType::Double3)";
case "Double4": return "VariantType(VariantType::Double4)";
case "Vector2": return "VariantType(VariantType::Vector2)";
case "Vector3": return "VariantType(VariantType::Vector3)";
case "Vector4": return "VariantType(VariantType::Vector4)";
case "Int2": return "VariantType(VariantType::Int2)";
case "Int3": return "VariantType(VariantType::Int3)";
case "Int4": return "VariantType(VariantType::Int4)";
case "Color": return "VariantType(VariantType::Color)";
case "BoundingBox": return "VariantType(VariantType::BoundingBox)";
case "BoundingSphere": return "VariantType(VariantType::BoundingSphere)";
case "Quaternion": return "VariantType(VariantType::Quaternion)";
case "Transform": return "VariantType(VariantType::Transform)";
case "Rectangle": return "VariantType(VariantType::Rectangle)";
case "Ray": return "VariantType(VariantType::Ray)";
case "Matrix": return "VariantType(VariantType::Matrix)";
case "Type": return "VariantType(VariantType::Typename)";
}
// Array
if (typeInfo.IsArray)
return "VariantType(VariantType::Array)";
if ((typeInfo.Type == "Array" || typeInfo.Type == "Span") && typeInfo.GenericArgs != null)
{
var elementType = FindApiTypeInfo(buildData, typeInfo.GenericArgs[0], caller);
var elementName = $"{(elementType != null ? elementType.FullNameManaged : typeInfo.GenericArgs[0].Type)}[]";
return $"VariantType(VariantType::Array, StringAnsiView(\"{elementName}\", {elementName.Length}))";
}
if (typeInfo.Type == "Dictionary" && typeInfo.GenericArgs != null)
return "VariantType(VariantType::Dictionary)";
// Scripting type
var apiType = FindApiTypeInfo(buildData, typeInfo, caller);
if (apiType != null)
{
// TODO: optimize VariantType for explicitly defined types to use static name and less mem/allocs
var fullname = apiType.FullNameManaged;
if (apiType.IsEnum)
return $"VariantType(VariantType::Enum, StringAnsiView(\"{fullname}\", {fullname.Length}))";
if (apiType.IsStruct)
return $"VariantType(VariantType::Structure, StringAnsiView(\"{fullname}\", {fullname.Length}))";
if (apiType.IsClass)
return $"VariantType(VariantType::Object, StringAnsiView(\"{fullname}\", {fullname.Length}))";
}
// Unknown
return "VariantType()";
}
public static string GenerateCppWrapperVariantToNative(BuildData buildData, TypeInfo typeInfo, ApiTypeInfo caller, string value)
{
if (typeInfo.Type == "Variant")
@@ -393,7 +468,7 @@ namespace Flax.Build.Bindings
return "Scripting::FindClass(\"" + managedType + "\")";
}
private static string GenerateCppGetNativeType(BuildData buildData, TypeInfo typeInfo, ApiTypeInfo caller, FunctionInfo functionInfo)
private static string GenerateCppGetNativeType(BuildData buildData, TypeInfo typeInfo, ApiTypeInfo caller, FunctionInfo functionInfo = null)
{
CppIncludeFiles.Add("Engine/Scripting/ManagedCLR/MClass.h");
@@ -1751,7 +1826,23 @@ namespace Flax.Build.Bindings
{
if (!functionInfo.IsVirtual)
continue;
contents.AppendLine($" scriptVTable[{scriptVTableIndex++}] = mclass->GetMethod(\"{functionInfo.Name}\", {functionInfo.Parameters.Count});");
// Don't use exact signature for parameter-less methods or the ones without duplicates
if (functionInfo.Parameters.Count == 0 || !classInfo.Functions.Any(x => x != functionInfo && x.Parameters.Count == functionInfo.Parameters.Count && x.Name == functionInfo.Name))
contents.AppendLine($" scriptVTable[{scriptVTableIndex++}] = mclass->GetMethod(\"{functionInfo.Name}\", {functionInfo.Parameters.Count});");
else
{
contents.AppendLine(" {");
contents.AppendLine(" ScriptingTypeMethodSignature signature;");
contents.AppendLine($" signature.Name = StringAnsiView(\"{functionInfo.Name}\", {functionInfo.Name.Length});");
contents.AppendLine($" signature.Params.Resize({functionInfo.Parameters.Count});");
for (var i = 0; i < functionInfo.Parameters.Count; i++)
contents.AppendLine($" signature.Params[{i}] = {{ {GenerateCppWrapperNativeToVariantType(buildData, functionInfo.Parameters[i].Type, classInfo)}, {(functionInfo.Parameters[i].IsOut ? "true" : "false")} }};");
contents.AppendLine($" scriptVTable[{scriptVTableIndex++}] = mclass->GetMethod(signature);");
if (buildData.Configuration != TargetConfiguration.Release)
contents.AppendLine($" ASSERT(scriptVTable[{scriptVTableIndex - 1}]);");
contents.AppendLine(" }");
}
}
contents.AppendLine(" }");
contents.AppendLine("");
@@ -2672,10 +2763,22 @@ namespace Flax.Build.Bindings
{
contents.AppendLine(" Variant* parameters = nullptr;");
}
if (functionInfo.Parameters.Count != 0)
{
// Build method signature to find method using exact parameter types to match on name collisions
contents.AppendLine(" ScriptingTypeMethodSignature signature;");
contents.AppendLine($" signature.Name = StringAnsiView(\"{functionInfo.Name}\", {functionInfo.Name.Length});");
contents.AppendLine($" signature.Params.Resize({functionInfo.Parameters.Count});");
for (var i = 0; i < functionInfo.Parameters.Count; i++)
contents.AppendLine($" signature.Params[{i}] = {{ parameters[{i}].Type, {(functionInfo.Parameters[i].IsOut ? "true" : "false")} }};");
}
contents.AppendLine(" auto typeHandle = Object->GetTypeHandle();");
contents.AppendLine(" while (typeHandle)");
contents.AppendLine(" {");
contents.AppendLine($" auto method = typeHandle.Module->FindMethod(typeHandle, StringAnsiView(\"{functionInfo.Name}\", {functionInfo.Name.Length}), {functionInfo.Parameters.Count});");
if (functionInfo.Parameters.Count == 0)
contents.AppendLine($" auto method = typeHandle.Module->FindMethod(typeHandle, StringAnsiView(\"{functionInfo.Name}\", {functionInfo.Name.Length}), {functionInfo.Parameters.Count});");
else
contents.AppendLine(" auto method = typeHandle.Module->FindMethod(typeHandle, signature);");
contents.AppendLine(" if (method)");
contents.AppendLine(" {");
contents.AppendLine(" Variant __result;");
@@ -2948,7 +3051,9 @@ namespace Flax.Build.Bindings
header.Append($"{wrapperName}Array(const {valueType}* v, const int32 length)").AppendLine();
header.Append('{').AppendLine();
header.Append(" Variant result;").AppendLine();
header.Append(" result.SetType(VariantType(VariantType::Array));").AppendLine();
var apiType = FindApiTypeInfo(buildData, valueType, moduleInfo);
var elementName = $"{(apiType != null ? apiType.FullNameManaged : valueType.Type)}[]";
header.Append($" result.SetType(VariantType(VariantType::Array, StringAnsiView(\"{elementName}\", {elementName.Length})));").AppendLine();
header.Append(" auto* array = reinterpret_cast<Array<Variant, HeapAllocation>*>(result.AsData);").AppendLine();
header.Append(" array->Resize(length);").AppendLine();
header.Append(" for (int32 i = 0; i < length; i++)").AppendLine();
@@ -3316,13 +3421,9 @@ namespace Flax.Build.Bindings
contents.AppendLine($"extern \"C\" BinaryModule* GetBinaryModule{binaryModuleName}()");
contents.AppendLine("{");
if (useCSharp)
{
contents.AppendLine($" static NativeBinaryModule module(\"{binaryModuleName}\");");
}
else
{
contents.AppendLine($" static NativeOnlyBinaryModule module(\"{binaryModuleName}\");");
}
contents.AppendLine(" return &module;");
contents.AppendLine("}");
if (project.VersionControlBranch.Length != 0)