Add support for interfaces in scripting API (cross language support C++/C#/VS)

This commit is contained in:
Wojtek Figat
2021-10-04 12:22:28 +02:00
parent 147e5ada46
commit c3c0a4ef0d
16 changed files with 1039 additions and 355 deletions

View File

@@ -81,6 +81,7 @@ ScriptingType::ScriptingType()
{
Script.Spawn = nullptr;
Script.VTable = nullptr;
Script.InterfacesOffsets = nullptr;
Script.ScriptVTable = nullptr;
Script.ScriptVTableBase = nullptr;
Script.SetupScriptVTable = nullptr;
@@ -101,6 +102,7 @@ ScriptingType::ScriptingType(const StringAnsiView& fullname, BinaryModule* modul
{
Script.Spawn = spawn;
Script.VTable = nullptr;
Script.InterfacesOffsets = nullptr;
Script.ScriptVTable = nullptr;
Script.ScriptVTableBase = nullptr;
Script.SetupScriptVTable = setupScriptVTable;
@@ -120,6 +122,7 @@ ScriptingType::ScriptingType(const StringAnsiView& fullname, BinaryModule* modul
{
Script.Spawn = spawn;
Script.VTable = nullptr;
Script.InterfacesOffsets = nullptr;
Script.ScriptVTable = nullptr;
Script.ScriptVTableBase = nullptr;
Script.SetupScriptVTable = setupScriptVTable;
@@ -160,16 +163,19 @@ ScriptingType::ScriptingType(const StringAnsiView& fullname, BinaryModule* modul
Struct.SetField = setField;
}
ScriptingType::ScriptingType(const StringAnsiView& fullname, BinaryModule* module, InitRuntimeHandler initRuntime, ScriptingTypeInitializer* baseType, const InterfaceImplementation* interfaces)
ScriptingType::ScriptingType(const StringAnsiView& fullname, BinaryModule* module, InitRuntimeHandler initRuntime, SetupScriptVTableHandler setupScriptVTable, SetupScriptObjectVTableHandler setupScriptObjectVTable, GetInterfaceWrapper getInterfaceWrapper)
: ManagedClass(nullptr)
, Module(module)
, InitRuntime(initRuntime)
, Fullname(fullname)
, Type(ScriptingTypes::Interface)
, BaseTypePtr(baseType)
, Interfaces(interfaces)
, BaseTypePtr(nullptr)
, Interfaces(nullptr)
, Size(0)
{
Interface.SetupScriptVTable = setupScriptVTable;
Interface.SetupScriptObjectVTable = setupScriptObjectVTable;
Interface.GetInterfaceWrapper = getInterfaceWrapper;
}
ScriptingType::ScriptingType(const ScriptingType& other)
@@ -188,6 +194,7 @@ ScriptingType::ScriptingType(const ScriptingType& other)
case ScriptingTypes::Script:
Script.Spawn = other.Script.Spawn;
Script.VTable = nullptr;
Script.InterfacesOffsets = nullptr;
Script.ScriptVTable = nullptr;
Script.ScriptVTableBase = nullptr;
Script.SetupScriptVTable = other.Script.SetupScriptVTable;
@@ -210,6 +217,9 @@ ScriptingType::ScriptingType(const ScriptingType& other)
case ScriptingTypes::Enum:
break;
case ScriptingTypes::Interface:
Interface.SetupScriptVTable = other.Interface.SetupScriptVTable;
Interface.SetupScriptObjectVTable = other.Interface.SetupScriptObjectVTable;
Interface.GetInterfaceWrapper = other.Interface.GetInterfaceWrapper;
break;
default: ;
}
@@ -232,6 +242,8 @@ ScriptingType::ScriptingType(ScriptingType&& other)
Script.Spawn = other.Script.Spawn;
Script.VTable = other.Script.VTable;
other.Script.VTable = nullptr;
Script.InterfacesOffsets = other.Script.InterfacesOffsets;
other.Script.InterfacesOffsets = nullptr;
Script.ScriptVTable = other.Script.ScriptVTable;
other.Script.ScriptVTable = nullptr;
Script.ScriptVTableBase = other.Script.ScriptVTableBase;
@@ -257,6 +269,9 @@ ScriptingType::ScriptingType(ScriptingType&& other)
case ScriptingTypes::Enum:
break;
case ScriptingTypes::Interface:
Interface.SetupScriptVTable = other.Interface.SetupScriptVTable;
Interface.SetupScriptObjectVTable = other.Interface.SetupScriptObjectVTable;
Interface.GetInterfaceWrapper = other.Interface.GetInterfaceWrapper;
break;
default: ;
}
@@ -271,6 +286,7 @@ ScriptingType::~ScriptingType()
Delete(Script.DefaultInstance);
if (Script.VTable)
Platform::Free((byte*)Script.VTable - GetVTablePrefix());
Platform::Free(Script.InterfacesOffsets);
Platform::Free(Script.ScriptVTable);
Platform::Free(Script.ScriptVTableBase);
break;
@@ -332,6 +348,172 @@ const ScriptingType::InterfaceImplementation* ScriptingType::GetInterface(const
return nullptr;
}
void ScriptingType::SetupScriptVTable(ScriptingTypeHandle baseTypeHandle)
{
// Call setup for all class starting from the first native type (first that uses virtual calls will allocate table of a proper size, further base types will just add own methods)
for (ScriptingTypeHandle e = baseTypeHandle; e;)
{
const ScriptingType& eType = e.GetType();
if (eType.Script.SetupScriptVTable)
{
ASSERT(eType.ManagedClass);
eType.Script.SetupScriptVTable(eType.ManagedClass, Script.ScriptVTable, Script.ScriptVTableBase);
}
auto interfaces = eType.Interfaces;
if (interfaces && Script.ScriptVTable)
{
while (interfaces->InterfaceType)
{
auto& interfaceType = interfaces->InterfaceType->GetType();
if (interfaceType.Interface.SetupScriptVTable)
{
ASSERT(eType.ManagedClass);
const auto scriptOffset = interfaces->ScriptVTableOffset; // Shift the script vtable for the interface implementation start
Script.ScriptVTable += scriptOffset;
Script.ScriptVTableBase += scriptOffset;
interfaceType.Interface.SetupScriptVTable(eType.ManagedClass, Script.ScriptVTable, Script.ScriptVTableBase);
Script.ScriptVTable -= scriptOffset;
Script.ScriptVTableBase -= scriptOffset;
}
interfaces++;
}
}
e = eType.GetBaseType();
}
}
void ScriptingType::SetupScriptObjectVTable(void* object, ScriptingTypeHandle baseTypeHandle, int32 wrapperIndex)
{
// Analyze vtable size
void** vtable = *(void***)object;
const int32 prefixSize = GetVTablePrefix();
int32 entriesCount = 0;
while (vtable[entriesCount] && entriesCount < 200)
entriesCount++;
// Calculate total vtable size by adding all implemented interfaces that use virtual methods
const int32 size = entriesCount * sizeof(void*);
int32 totalSize = prefixSize + size;
int32 interfacesCount = 0;
for (ScriptingTypeHandle e = baseTypeHandle; e;)
{
const ScriptingType& eType = e.GetType();
auto interfaces = eType.Interfaces;
if (interfaces)
{
while (interfaces->InterfaceType)
{
auto& interfaceType = interfaces->InterfaceType->GetType();
if (interfaceType.Interface.SetupScriptObjectVTable)
{
void** vtableInterface = *(void***)((byte*)object + interfaces->VTableOffset);
int32 interfaceCount = 0;
while (vtableInterface[interfaceCount] && interfaceCount < 200)
interfaceCount++;
totalSize += prefixSize + interfaceCount * sizeof(void*);
interfacesCount++;
}
interfaces++;
}
}
e = eType.GetBaseType();
}
// Duplicate vtable
Script.VTable = (void**)((byte*)Platform::Allocate(totalSize, 16) + prefixSize);
Platform::MemoryCopy((byte*)Script.VTable - prefixSize, (byte*)vtable - prefixSize, prefixSize + size);
// Override vtable entries
if (interfacesCount)
Script.InterfacesOffsets = (uint16*)Platform::Allocate(interfacesCount * sizeof(uint16*), 16);
int32 interfaceOffset = size;
interfacesCount = 0;
for (ScriptingTypeHandle e = baseTypeHandle; e;)
{
const ScriptingType& eType = e.GetType();
if (eType.Script.SetupScriptObjectVTable)
{
// Override vtable entries for this class
eType.Script.SetupScriptObjectVTable(Script.ScriptVTable, Script.ScriptVTableBase, Script.VTable, entriesCount, wrapperIndex);
}
auto interfaces = eType.Interfaces;
if (interfaces)
{
while (interfaces->InterfaceType)
{
auto& interfaceType = interfaces->InterfaceType->GetType();
if (interfaceType.Interface.SetupScriptObjectVTable)
{
// Analyze interface vtable size
void** vtableInterface = *(void***)((byte*)object + interfaces->VTableOffset);
int32 interfaceCount = 0;
while (vtableInterface[interfaceCount] && interfaceCount < 200)
interfaceCount++;
const int32 interfaceSize = interfaceCount * sizeof(void*);
// Duplicate interface vtable
Platform::MemoryCopy((byte*)Script.VTable + interfaceOffset, (byte*)vtableInterface - prefixSize, prefixSize + interfaceSize);
// Override interface vtable entries
const auto scriptOffset = interfaces->ScriptVTableOffset;
const auto nativeOffset = interfaceOffset + prefixSize;
void** interfaceVTable = (void**)((byte*)Script.VTable + nativeOffset);
interfaceType.Interface.SetupScriptObjectVTable(Script.ScriptVTable + scriptOffset, Script.ScriptVTableBase + scriptOffset, interfaceVTable, interfaceCount, wrapperIndex);
Script.InterfacesOffsets[interfacesCount++] = (uint16)nativeOffset;
interfaceOffset += prefixSize + interfaceSize;
}
interfaces++;
}
}
e = eType.GetBaseType();
}
}
void ScriptingType::HackObjectVTable(void* object, ScriptingTypeHandle baseTypeHandle, int32 wrapperIndex)
{
if (!Script.ScriptVTable)
return;
if (!Script.VTable)
{
// Ensure to have valid Script VTable hacked
SetupScriptObjectVTable(object, baseTypeHandle, wrapperIndex);
}
// Override object vtable with hacked one that has calls to overriden scripting functions
*(void**)object = Script.VTable;
if (Script.InterfacesOffsets)
{
// Override vtables for interfaces
int32 interfacesCount = 0;
for (ScriptingTypeHandle e = baseTypeHandle; e;)
{
const ScriptingType& eType = e.GetType();
auto interfaces = eType.Interfaces;
if (interfaces)
{
while (interfaces->InterfaceType)
{
auto& interfaceType = interfaces->InterfaceType->GetType();
if (interfaceType.Interface.SetupScriptObjectVTable)
{
void** interfaceVTable = (void**)((byte*)Script.VTable + Script.InterfacesOffsets[interfacesCount++]);
*(void**)((byte*)object + interfaces->VTableOffset) = interfaceVTable;
interfacesCount++;
}
interfaces++;
}
}
e = eType.GetBaseType();
}
}
}
String ScriptingType::ToString() const
{
return String(Fullname.Get(), Fullname.Length());
@@ -382,12 +564,12 @@ ScriptingTypeInitializer::ScriptingTypeInitializer(BinaryModule* module, const S
module->TypeNameToTypeIndex[fullname] = TypeIndex;
}
ScriptingTypeInitializer::ScriptingTypeInitializer(BinaryModule* module, const StringAnsiView& fullname, ScriptingType::InitRuntimeHandler initRuntime, ScriptingTypeInitializer* baseType, const ScriptingType::InterfaceImplementation* interfaces)
ScriptingTypeInitializer::ScriptingTypeInitializer(BinaryModule* module, const StringAnsiView& fullname, ScriptingType::InitRuntimeHandler initRuntime, ScriptingType::SetupScriptVTableHandler setupScriptVTable, ScriptingType::SetupScriptObjectVTableHandler setupScriptObjectVTable, ScriptingType::GetInterfaceWrapper getInterfaceWrapper)
: ScriptingTypeHandle(module, module->Types.Count())
{
// Interface
module->Types.AddUninitialized();
new(module->Types.Get() + TypeIndex)ScriptingType(fullname, module, initRuntime, baseType, interfaces);
new(module->Types.Get() + TypeIndex)ScriptingType(fullname, module, initRuntime, setupScriptVTable, setupScriptObjectVTable, getInterfaceWrapper);
#if BUILD_DEBUG
if (module->TypeNameToTypeIndex.ContainsKey(fullname))
{
@@ -508,39 +690,7 @@ ScriptingObject* ManagedBinaryModule::ManagedObjectSpawn(const ScriptingObjectSp
}
// Beware! Hacking vtables incoming! Undefined behaviors exploits! Low-level programming!
// What's happening here?
// We create a custom vtable for the C# objects that use a native class object with virtual functions overrides.
// To make it easy to use in C++ we inject custom wrapper methods into C++ object vtable to call C# code from them.
// Because virtual member functions calls are C++ ABI and impl-defined this is quite hard. But works.
if (managedType.Script.ScriptVTable)
{
if (!managedType.Script.VTable)
{
// Duplicate vtable
void** vtable = *(void***)object;
const int32 prefixSize = GetVTablePrefix();
int32 entriesCount = 0;
while (vtable[entriesCount] && entriesCount < 200)
entriesCount++;
const int32 size = entriesCount * sizeof(void*);
managedType.Script.VTable = (void**)((byte*)Platform::Allocate(prefixSize + size, 16) + prefixSize);
Platform::MemoryCopy((byte*)managedType.Script.VTable - prefixSize, (byte*)vtable - prefixSize, prefixSize + size);
// Override vtable entries by the class
for (ScriptingTypeHandle e = nativeTypeHandle; e;)
{
const ScriptingType& eType = e.GetType();
if (eType.Script.SetupScriptObjectVTable)
{
eType.Script.SetupScriptObjectVTable(managedType.Script.ScriptVTable, managedType.Script.ScriptVTableBase, managedType.Script.VTable, entriesCount, 0);
}
e = eType.GetBaseType();
}
}
// Override object vtable with hacked one that has C# functions calls
*(void**)object = managedType.Script.VTable;
}
managedType.HackObjectVTable(object, nativeTypeHandle, 0);
// Mark as managed type
object->Flags |= ObjectFlags::IsManagedType;
@@ -618,6 +768,20 @@ ManagedBinaryModule* ManagedBinaryModule::FindModule(MonoClass* klass)
return module;
}
ScriptingTypeHandle ManagedBinaryModule::FindType(MonoClass* klass)
{
auto typeModule = FindModule(klass);
if (typeModule)
{
int32 typeIndex;
if (typeModule->ClassToTypeIndex.TryGet(klass, typeIndex))
{
return ScriptingTypeHandle(typeModule, typeIndex);
}
}
return ScriptingTypeHandle();
}
void ManagedBinaryModule::OnLoading(MAssembly* assembly)
{
PROFILE_CPU();
@@ -672,13 +836,9 @@ void ManagedBinaryModule::OnLoaded(MAssembly* assembly)
for (auto i = classes.Begin(); i.IsNotEnd(); ++i)
{
MClass* mclass = i->Value;
const MString& typeName = mclass->GetFullName();
// Check if C# class inherits from C++ object class it has no C++ representation
if (
TypeNameToTypeIndex.Find(typeName) != TypeNameToTypeIndex.End() ||
mclass->IsStatic() ||
mclass->IsAbstract() ||
if (mclass->IsStatic() ||
mclass->IsInterface() ||
!mclass->IsSubClassOf(scriptingObjectType)
)
@@ -686,96 +846,151 @@ void ManagedBinaryModule::OnLoaded(MAssembly* assembly)
continue;
}
// Find first native base C++ class of this C# class
MClass* baseClass = mclass->GetBaseClass();
ScriptingTypeHandle nativeType;
while (baseClass)
InitType(mclass);
}
}
}
void ManagedBinaryModule::InitType(MClass* mclass)
{
// Skip if already initialized
const MString& typeName = mclass->GetFullName();
if (TypeNameToTypeIndex.ContainsKey(typeName))
return;
// Find first native base C++ class of this C# class
MClass* baseClass = nullptr;
MonoClass* baseKlass = mono_class_get_parent(mclass->GetNative());
MonoImage* baseKlassImage = mono_class_get_image(baseKlass);
ScriptingTypeHandle baseType;
auto& modules = GetModules();
for (int32 i = 0; i < modules.Count(); i++)
{
auto e = dynamic_cast<ManagedBinaryModule*>(modules[i]);
if (e && e->Assembly->GetMonoImage() == baseKlassImage)
{
baseType.Module = e;
baseClass = e->Assembly->GetClass(baseKlass);
break;
}
}
if (!baseClass)
{
LOG(Error, "Missing base class for managed class {0} from assembly {1}.", String(typeName), Assembly->ToString());
return;
}
if (baseType.Module == this)
InitType(baseClass); // Ensure base is initialized before
baseType.Module->TypeNameToTypeIndex.TryGet(baseClass->GetFullName(), *(int32*)&baseType.TypeIndex);
if (!baseType)
{
LOG(Error, "Missing base class for managed class {0} from assembly {1}.", String(typeName), Assembly->ToString());
return;
}
ScriptingTypeHandle nativeType = baseType;
while (true)
{
auto& type = nativeType.GetType();
if (type.Script.Spawn != &ManagedObjectSpawn)
break;
nativeType = type.GetBaseType();
if (!nativeType)
{
LOG(Error, "Missing base class for managed class {0} from assembly {1}.", String(typeName), Assembly->ToString());
return;
}
}
// Scripting Type has Fullname span pointing to the string in memory (usually static data) so store the name in assembly
char* typeNameData = (char*)Allocator::Allocate(typeName.Length() + 1);
Platform::MemoryCopy(typeNameData, typeName.Get(), typeName.Length());
typeNameData[typeName.Length()] = 0;
_managedMemoryBlocks.Add(typeNameData);
// Initialize scripting interfaces implemented in C#
MonoClass* interfaceKlass;
void* interfaceIt = nullptr;
int32 interfacesCount = 0;
MonoClass* klass = mclass->GetNative();
while (interfaceKlass = mono_class_get_interfaces(klass, &interfaceIt))
{
const ScriptingTypeHandle interfaceType = FindType(interfaceKlass);
if (interfaceType)
interfacesCount++;
}
ScriptingType::InterfaceImplementation* interfaces = nullptr;
if (interfacesCount != 0)
{
interfaces = (ScriptingType::InterfaceImplementation*)Allocator::Allocate((interfacesCount + 1) * sizeof(ScriptingType::InterfaceImplementation));
interfacesCount = 0;
interfaceIt = nullptr;
while (interfaceKlass = mono_class_get_interfaces(klass, &interfaceIt))
{
const ScriptingTypeHandle interfaceTypeHandle = FindType(interfaceKlass);
if (!interfaceTypeHandle)
continue;
auto& interface = interfaces[interfacesCount++];
auto ptr = (ScriptingTypeHandle*)Allocator::Allocate(sizeof(ScriptingTypeHandle));
*ptr = interfaceTypeHandle;
_managedMemoryBlocks.Add(ptr);
interface.InterfaceType = ptr;
interface.VTableOffset = 0;
interface.ScriptVTableOffset = 0;
interface.IsNative = false;
}
Platform::MemoryClear(interfaces + interfacesCount, sizeof(ScriptingType::InterfaceImplementation));
_managedMemoryBlocks.Add(interfaces);
}
// Create scripting type descriptor for managed-only type based on the native base class
const int32 typeIndex = Types.Count();
Types.AddUninitialized();
new(Types.Get() + Types.Count() - 1)ScriptingType(typeName, this, baseType.GetType().Size, ScriptingType::DefaultInitRuntime, ManagedObjectSpawn, baseType, nullptr, nullptr, interfaces);
TypeNameToTypeIndex[typeName] = typeIndex;
auto& type = Types[typeIndex];
type.ManagedClass = mclass;
// Register Mono class
ASSERT(!ClassToTypeIndex.ContainsKey(klass));
ClassToTypeIndex[klass] = typeIndex;
// Create managed vtable for this class (build out of the wrapper C++ methods that call C# methods)
type.SetupScriptVTable(nativeType);
MMethod** scriptVTable = (MMethod**)type.Script.ScriptVTable;
while (scriptVTable && *scriptVTable)
{
const MMethod* referenceMethod = *scriptVTable;
// Find that method overriden in C# class (the current or one of the base classes in C#)
MMethod* method = ::FindMethod(mclass, referenceMethod);
if (method == nullptr)
{
// Check base classes (skip native class)
baseClass = mclass->GetBaseClass();
MClass* nativeBaseClass = nativeType.GetType().ManagedClass;
while (baseClass && baseClass != nativeBaseClass && method == nullptr)
{
int32 typeIndex;
BinaryModule* baseClassModule = GetModule(baseClass->GetAssembly());
ASSERT(baseClassModule);
if (baseClassModule->TypeNameToTypeIndex.TryGet(baseClass->GetFullName(), typeIndex))
method = ::FindMethod(baseClass, referenceMethod);
// Special case if method was found but the base class uses generic arguments
if (method && baseClass->IsGeneric())
{
nativeType = ScriptingTypeHandle(baseClassModule, typeIndex);
if (nativeType.GetType().Script.Spawn != &ManagedObjectSpawn)
break;
// TODO: encapsulate it into MClass to support inflated methods
auto parentClass = mono_class_get_parent(mclass->GetNative());
auto parentMethod = mono_class_get_method_from_name(parentClass, referenceMethod->GetName().Get(), 0);
auto inflatedMethod = mono_class_inflate_generic_method(parentMethod, nullptr);
method = New<MMethod>(inflatedMethod, baseClass);
}
baseClass = baseClass->GetBaseClass();
}
if (!nativeType)
{
LOG(Error, "Missing native base class for managed class {0} from assembly {1}.", String(typeName), assembly->ToString());
continue;
}
// Scripting Type has Fullname span pointing to the string in memory (usually static data) so store the name in assembly
char* typeNameData = (char*)Allocator::Allocate(typeName.Length() + 1);
Platform::MemoryCopy(typeNameData, typeName.Get(), typeName.Length());
typeNameData[typeName.Length()] = 0;
_managedTypesNames.Add(typeNameData);
// Create scripting type descriptor for managed-only type based on the native base class
const int32 typeIndex = Types.Count();
Types.AddUninitialized();
new(Types.Get() + Types.Count() - 1)ScriptingType(typeName, this, nativeType.GetType().Size, ScriptingType::DefaultInitRuntime, ManagedObjectSpawn, nativeType);
TypeNameToTypeIndex[typeName] = typeIndex;
auto& type = Types[typeIndex];
type.ManagedClass = mclass;
// Register Mono class
MonoClass* klass = mclass->GetNative();
ASSERT(!ClassToTypeIndex.ContainsKey(klass));
ClassToTypeIndex[klass] = typeIndex;
// Create managed vtable for this class (build out of the wrapper C++ methods that call C# methods)
// Call setup for all class starting from the first native type (first that uses virtual calls will allocate table of a proper size, further base types will just add own methods)
for (ScriptingTypeHandle e = nativeType; e;)
{
const ScriptingType& eType = e.GetType();
if (eType.Script.SetupScriptVTable)
{
ASSERT(eType.ManagedClass);
eType.Script.SetupScriptVTable(eType.ManagedClass, type.Script.ScriptVTable, type.Script.ScriptVTableBase);
}
e = eType.GetBaseType();
}
MMethod** scriptVTable = (MMethod**)type.Script.ScriptVTable;
while (scriptVTable && *scriptVTable)
{
const MMethod* referenceMethod = *scriptVTable;
// Find that method overriden in C# class (the current or one of the base classes in C#)
MMethod* method = ::FindMethod(mclass, referenceMethod);
if (method == nullptr)
{
// Check base classes (skip native class)
baseClass = mclass->GetBaseClass();
MClass* nativeBaseClass = nativeType.GetType().ManagedClass;
while (baseClass && baseClass != nativeBaseClass && method == nullptr)
{
method = ::FindMethod(baseClass, referenceMethod);
// Special case if method was found but the base class uses generic arguments
if (method && baseClass->IsGeneric())
{
// TODO: encapsulate it into MClass to support inflated methods
auto parentClass = mono_class_get_parent(mclass->GetNative());
auto parentMethod = mono_class_get_method_from_name(parentClass, referenceMethod->GetName().Get(), 0);
auto inflatedMethod = mono_class_inflate_generic_method(parentMethod, nullptr);
method = New<MMethod>(inflatedMethod, baseClass);
}
baseClass = baseClass->GetBaseClass();
}
}
// Set the method to call (null entry marks unused entries that won't use C# wrapper calls)
*scriptVTable = method;
// Move to the next entry (table is null terminated)
scriptVTable++;
}
}
// Set the method to call (null entry marks unused entries that won't use C# wrapper calls)
*scriptVTable = method;
// Move to the next entry (table is null terminated)
scriptVTable++;
}
}
@@ -791,9 +1006,9 @@ void ManagedBinaryModule::OnUnloading(MAssembly* assembly)
TypeNameToTypeIndex.Remove(typeName);
}
Types.Resize(_firstManagedTypeIndex);
for (int32 i = 0; i < _managedTypesNames.Count(); i++)
Allocator::Free(_managedTypesNames[i]);
_managedTypesNames.Clear();
for (int32 i = 0; i < _managedMemoryBlocks.Count(); i++)
Allocator::Free(_managedMemoryBlocks[i]);
_managedMemoryBlocks.Clear();
// Clear managed types information
for (ScriptingType& type : Types)

View File

@@ -261,7 +261,7 @@ public:
private:
int32 _firstManagedTypeIndex;
Array<char*> _managedTypesNames;
Array<void*> _managedMemoryBlocks;
public:
@@ -298,11 +298,13 @@ public:
static ScriptingObject* ManagedObjectSpawn(const ScriptingObjectSpawnParams& params);
static MMethod* FindMethod(MClass* mclass, const ScriptingTypeMethodSignature& signature);
static ManagedBinaryModule* FindModule(MonoClass* klass);
static ScriptingTypeHandle FindType(MonoClass* klass);
private:
void OnLoading(MAssembly* assembly);
void OnLoaded(MAssembly* assembly);
void InitType(MClass* mclass);
void OnUnloading(MAssembly* assembly);
public:

View File

@@ -58,7 +58,6 @@ public:
/// <summary>
/// Gets the parent assembly.
/// </summary>
/// <returns>The assembly.</returns>
const MAssembly* GetAssembly() const
{
return _assembly;
@@ -67,7 +66,6 @@ public:
/// <summary>
/// Gets the full name of the class (namespace and typename).
/// </summary>
/// <returns>The fullname.</returns>
FORCE_INLINE const MString& GetFullName() const
{
return _fullname;
@@ -78,7 +76,6 @@ public:
/// <summary>
/// Gets the Mono class handle.
/// </summary>
/// <returns>The Mono class.</returns>
MonoClass* GetNative() const;
#endif
@@ -86,7 +83,6 @@ public:
/// <summary>
/// Gets class visibility
/// </summary>
/// <returns>Returns visibility struct.</returns>
FORCE_INLINE MVisibility GetVisibility() const
{
return _visibility;
@@ -95,7 +91,6 @@ public:
/// <summary>
/// Gets if class is static
/// </summary>
/// <returns>Returns true if class is static, otherwise false.</returns>
FORCE_INLINE bool IsStatic() const
{
return _isStatic != 0;
@@ -104,7 +99,6 @@ public:
/// <summary>
/// Gets if class is abstract
/// </summary>
/// <returns>Returns true if class is static, otherwise false.</returns>
FORCE_INLINE bool IsAbstract() const
{
return _isAbstract != 0;
@@ -113,7 +107,6 @@ public:
/// <summary>
/// Gets if class is sealed
/// </summary>
/// <returns>Returns true if class is static, otherwise false.</returns>
FORCE_INLINE bool IsSealed() const
{
return _isSealed != 0;
@@ -122,7 +115,6 @@ public:
/// <summary>
/// Gets if class is interface
/// </summary>
/// <returns>Returns true if class is static, otherwise false.</returns>
FORCE_INLINE bool IsInterface() const
{
return _isInterface != 0;
@@ -131,19 +123,16 @@ public:
/// <summary>
/// Gets if class is generic
/// </summary>
/// <returns>Returns true if class is generic type, otherwise false.</returns>
bool IsGeneric() const;
/// <summary>
/// Gets the class type.
/// </summary>
/// <returns>The type.</returns>
MType GetType() const;
/// <summary>
/// Returns the base class of this class. Null if this class has no base.
/// </summary>
/// <returns>The base class.</returns>
MClass* GetBaseClass() const;
/// <summary>
@@ -170,7 +159,6 @@ public:
/// <summary>
/// Returns the size of an instance of this class, in bytes.
/// </summary>
/// <returns>The instance size (in bytes).</returns>
uint32 GetInstanceSize() const;
public:

View File

@@ -210,6 +210,17 @@ namespace FlaxEngine
return GetUnmanagedPtr(reference.Get<Object>());
}
/// <summary>
/// Gets the pointer to the native interface implementation. Handles null object reference or invalid cast (returns zero).
/// </summary>
/// <param name="obj">The object.</param>
/// <param name="type">The interface type.</param>
/// <returns>The native interface pointer.</returns>
public static IntPtr GetUnmanagedInterface(object obj, Type type)
{
return obj is Object o ? Internal_GetUnmanagedInterface(o.__unmanagedPtr, type) : IntPtr.Zero;
}
/// <inheritdoc />
public override int GetHashCode()
{
@@ -245,6 +256,9 @@ namespace FlaxEngine
[MethodImpl(MethodImplOptions.InternalCall)]
internal static extern void Internal_ChangeID(IntPtr obj, ref Guid id);
[MethodImpl(MethodImplOptions.InternalCall)]
internal static extern IntPtr Internal_GetUnmanagedInterface(IntPtr obj, Type type);
#endregion
}
}

View File

@@ -9,6 +9,7 @@
#include "Engine/Utilities/StringConverter.h"
#include "Engine/Content/Asset.h"
#include "Engine/Content/Content.h"
#include "Engine/Profiler/ProfilerCPU.h"
#include "ManagedCLR/MAssembly.h"
#include "ManagedCLR/MClass.h"
#include "ManagedCLR/MUtils.h"
@@ -23,6 +24,9 @@
#define ScriptingObject_unmanagedPtr "__unmanagedPtr"
#define ScriptingObject_id "__internalId"
// TODO: don't leak memory (use some kind of late manual GC for those wrapper objects)
Dictionary<ScriptingObject*, void*> ScriptingObjectsInterfaceWrappers;
ScriptingObject::ScriptingObject(const SpawnParams& params)
: _gcHandle(0)
, _type(params.Type)
@@ -71,6 +75,51 @@ MClass* ScriptingObject::GetClass() const
return _type ? _type.GetType().ManagedClass : nullptr;
}
ScriptingObject* ScriptingObject::FromInterface(void* interfaceObj, ScriptingTypeHandle& interfaceType)
{
if (!interfaceObj || !interfaceType)
return nullptr;
PROFILE_CPU();
// Find the type which implements this interface and has the same vtable as interface object
// TODO: implement vtableInterface->type hashmap caching in Scripting service to optimize sequential interface casts
auto& modules = BinaryModule::GetModules();
for (auto module : modules)
{
for (auto& type : module->Types)
{
if (type.Type != ScriptingTypes::Script)
continue;
auto interfaceImpl = type.GetInterface(interfaceType);
if (interfaceImpl && interfaceImpl->IsNative)
{
ScriptingObject* predictedObj = (ScriptingObject*)((byte*)interfaceObj - interfaceImpl->VTableOffset);
void* predictedVTable = *(void***)predictedObj;
void* vtable = type.Script.VTable;
if (!vtable && type.GetDefaultInstance())
{
// Use vtable from default instance of this type
vtable = *(void***)type.GetDefaultInstance();
}
if (vtable == predictedVTable)
{
ASSERT(predictedObj->GetType().GetInterface(interfaceType));
return predictedObj;
}
}
}
}
// Special case for interface wrapper object
for (auto& e : ScriptingObjectsInterfaceWrappers)
{
if (e.Value == interfaceObj)
return e.Key;
}
return nullptr;
}
ScriptingObject* ScriptingObject::ToNative(MonoObject* obj)
{
ScriptingObject* ptr = nullptr;
@@ -575,6 +624,37 @@ public:
obj->ChangeID(*id);
}
static void* GetUnmanagedInterface(ScriptingObject* obj, MonoReflectionType* type)
{
if (obj && type)
{
auto typeClass = MUtils::GetClass(type);
const ScriptingTypeHandle interfaceType = ManagedBinaryModule::FindType(typeClass);
if (interfaceType)
{
const ScriptingType& objectType = obj->GetType();
const ScriptingType::InterfaceImplementation* interface = objectType.GetInterface(interfaceType);
if (interface && interface->IsNative)
{
// Native interface so just offset pointer to the interface vtable start
return (byte*)obj + interface->VTableOffset;
}
if (interface)
{
// Interface implemented in scripting (eg. C# class inherits C++ interface)
void* result;
if (!ScriptingObjectsInterfaceWrappers.TryGet(obj, result))
{
result = interfaceType.GetType().Interface.GetInterfaceWrapper(obj);
ScriptingObjectsInterfaceWrappers.Add(obj, result);
}
return result;
}
}
}
return nullptr;
}
static void InitRuntime()
{
ADD_INTERNAL_CALL("FlaxEngine.Object::Internal_Create1", &Create1);
@@ -586,6 +666,7 @@ public:
ADD_INTERNAL_CALL("FlaxEngine.Object::Internal_FindObject", &FindObject);
ADD_INTERNAL_CALL("FlaxEngine.Object::Internal_TryFindObject", &TryFindObject);
ADD_INTERNAL_CALL("FlaxEngine.Object::Internal_ChangeID", &ChangeID);
ADD_INTERNAL_CALL("FlaxEngine.Object::Internal_GetUnmanagedInterface", &GetUnmanagedInterface);
}
static ScriptingObject* Spawn(const ScriptingObjectSpawnParams& params)

View File

@@ -108,6 +108,8 @@ public:
public:
// Tries to cast native interface object to scripting object instance. Returns null if fails.
static ScriptingObject* FromInterface(void* interfaceObj, ScriptingTypeHandle& interfaceType);
static ScriptingObject* ToNative(MonoObject* obj);
static MonoObject* ToManaged(ScriptingObject* obj)

View File

@@ -117,6 +117,7 @@ struct FLAXENGINE_API ScriptingType
typedef void (*Unbox)(void* ptr, MonoObject* managed);
typedef void (*GetField)(void* ptr, const String& name, Variant& value);
typedef void (*SetField)(void* ptr, const String& name, const Variant& value);
typedef void* (*GetInterfaceWrapper)(ScriptingObject* obj);
struct InterfaceImplementation
{
@@ -125,6 +126,12 @@ struct FLAXENGINE_API ScriptingType
// The offset (in bytes) from the object pointer to the interface implementation. Used for casting object to the interface.
int16 VTableOffset;
// The offset (in entries) from the script vtable to the interface implementation. Used for initializing interface virtual script methods.
int16 ScriptVTableOffset;
// True if interface implementation is native (inside C++ object), otherwise it's injected at scripting level (cannot call interface directly).
bool IsNative;
};
/// <summary>
@@ -186,6 +193,11 @@ struct FLAXENGINE_API ScriptingType
/// </summary>
void** VTable;
/// <summary>
/// List of offsets from native methods VTable for each interface (with virtual methods). Null if not using interfaces with method overrides.
/// </summary>
uint16* InterfacesOffsets;
/// <summary>
/// The script methods VTable used by the wrapper functions attached to native object vtable. Cached to improve C#/VisualScript invocation performance.
/// </summary>
@@ -244,6 +256,13 @@ struct FLAXENGINE_API ScriptingType
// Class destructor method pointer
Dtor Dtor;
} Class;
struct
{
SetupScriptVTableHandler SetupScriptVTable;
SetupScriptObjectVTableHandler SetupScriptObjectVTable;
GetInterfaceWrapper GetInterfaceWrapper;
} Interface;
};
ScriptingType();
@@ -251,7 +270,7 @@ struct FLAXENGINE_API ScriptingType
ScriptingType(const StringAnsiView& fullname, BinaryModule* module, int32 size, InitRuntimeHandler initRuntime = DefaultInitRuntime, SpawnHandler spawn = DefaultSpawn, ScriptingTypeInitializer* baseType = nullptr, SetupScriptVTableHandler setupScriptVTable = nullptr, SetupScriptObjectVTableHandler setupScriptObjectVTable = nullptr, const InterfaceImplementation* interfaces = nullptr);
ScriptingType(const StringAnsiView& fullname, BinaryModule* module, int32 size, InitRuntimeHandler initRuntime, Ctor ctor, Dtor dtor, ScriptingTypeInitializer* baseType, const InterfaceImplementation* interfaces = nullptr);
ScriptingType(const StringAnsiView& fullname, BinaryModule* module, int32 size, InitRuntimeHandler initRuntime, Ctor ctor, Dtor dtor, Copy copy, Box box, Unbox unbox, GetField getField, SetField setField, ScriptingTypeInitializer* baseType, const InterfaceImplementation* interfaces = nullptr);
ScriptingType(const StringAnsiView& fullname, BinaryModule* module, InitRuntimeHandler initRuntime, ScriptingTypeInitializer* baseType, const InterfaceImplementation* interfaces = nullptr);
ScriptingType(const StringAnsiView& fullname, BinaryModule* module, InitRuntimeHandler initRuntime, SetupScriptVTableHandler setupScriptVTable, SetupScriptObjectVTableHandler setupScriptObjectVTable, GetInterfaceWrapper getInterfaceWrapper);
ScriptingType(const ScriptingType& other);
ScriptingType(ScriptingType&& other);
ScriptingType& operator=(ScriptingType&& other) = delete;
@@ -292,6 +311,9 @@ struct FLAXENGINE_API ScriptingType
/// </summary>
const InterfaceImplementation* GetInterface(const ScriptingTypeHandle& interfaceType) const;
void SetupScriptVTable(ScriptingTypeHandle baseTypeHandle);
void SetupScriptObjectVTable(void* object, ScriptingTypeHandle baseTypeHandle, int32 wrapperIndex);
void HackObjectVTable(void* object, ScriptingTypeHandle baseTypeHandle, int32 wrapperIndex);
String ToString() const;
};
@@ -303,7 +325,7 @@ struct FLAXENGINE_API ScriptingTypeInitializer : ScriptingTypeHandle
ScriptingTypeInitializer(BinaryModule* module, const StringAnsiView& fullname, int32 size, ScriptingType::InitRuntimeHandler initRuntime = ScriptingType::DefaultInitRuntime, ScriptingType::SpawnHandler spawn = ScriptingType::DefaultSpawn, ScriptingTypeInitializer* baseType = nullptr, ScriptingType::SetupScriptVTableHandler setupScriptVTable = nullptr, ScriptingType::SetupScriptObjectVTableHandler setupScriptObjectVTable = nullptr, const ScriptingType::InterfaceImplementation* interfaces = nullptr);
ScriptingTypeInitializer(BinaryModule* module, const StringAnsiView& fullname, int32 size, ScriptingType::InitRuntimeHandler initRuntime, ScriptingType::Ctor ctor, ScriptingType::Dtor dtor, ScriptingTypeInitializer* baseType = nullptr, const ScriptingType::InterfaceImplementation* interfaces = nullptr);
ScriptingTypeInitializer(BinaryModule* module, const StringAnsiView& fullname, int32 size, ScriptingType::InitRuntimeHandler initRuntime, ScriptingType::Ctor ctor, ScriptingType::Dtor dtor, ScriptingType::Copy copy, ScriptingType::Box box, ScriptingType::Unbox unbox, ScriptingType::GetField getField, ScriptingType::SetField setField, ScriptingTypeInitializer* baseType = nullptr, const ScriptingType::InterfaceImplementation* interfaces = nullptr);
ScriptingTypeInitializer(BinaryModule* module, const StringAnsiView& fullname, ScriptingType::InitRuntimeHandler initRuntime, ScriptingTypeInitializer* baseType = nullptr, const ScriptingType::InterfaceImplementation* interfaces = nullptr);
ScriptingTypeInitializer(BinaryModule* module, const StringAnsiView& fullname, ScriptingType::InitRuntimeHandler initRuntime, ScriptingType::SetupScriptVTableHandler setupScriptVTable, ScriptingType::SetupScriptObjectVTableHandler setupScriptObjectVTable, ScriptingType::GetInterfaceWrapper getInterfaceWrapper);
};
/// <summary>
@@ -317,7 +339,7 @@ struct ScriptingObjectSpawnParams
Guid ID;
/// <summary>
/// The object type handle (script class might not be loaded yet).
/// The object type handle.
/// </summary>
const ScriptingTypeHandle Type;