Add C++ lambda support for Function<> and Delegate<>

This commit is contained in:
Wojtek Figat
2021-10-05 13:07:38 +02:00
parent 32794f89c7
commit cabd06dd53

View File

@@ -5,7 +5,7 @@
#include "Engine/Core/Memory/Allocation.h"
/// <summary>
/// The function object.
/// The function object that supports binding static, member and lambda functions.
/// </summary>
template<typename ReturnType, typename... Params>
class Function<ReturnType(Params ...)>
@@ -45,9 +45,28 @@ private:
return (reinterpret_cast<T*>(callee)->*Method)(Forward<Params>(params)...);
}
struct Lambda
{
int64 Refs;
void (*Dtor)(void*);
};
void* _callee;
StubSignature _function;
Lambda* _lambda;
FORCE_INLINE void LambdaCtor() const
{
Platform::InterlockedIncrement((int64 volatile*)&_lambda->Refs);
}
FORCE_INLINE void LambdaDtor()
{
if (Platform::InterlockedDecrement(&_lambda->Refs) == 0)
{
((Lambda*)_lambda)->Dtor(_callee);
Allocator::Free(_lambda);
}
}
public:
/// <summary>
@@ -57,6 +76,7 @@ public:
{
_callee = nullptr;
_function = nullptr;
_lambda = nullptr;
}
/// <summary>
@@ -64,9 +84,32 @@ public:
/// </summary>
Function(Signature method)
{
ASSERT(method);
_callee = (void*)method;
_function = &StaticPointerMethodStub;
_lambda = nullptr;
}
Function(const Function& other)
: _callee(other._callee)
, _function(other._function)
, _lambda(other._lambda)
{
if (_lambda)
LambdaCtor();
}
Function(Function&& other) noexcept
: _callee(other._callee)
, _function(other._function)
, _lambda(other._lambda)
{
other._lambda = nullptr;
}
~Function()
{
if (_lambda)
LambdaDtor();
}
public:
@@ -77,8 +120,11 @@ public:
template<ReturnType (*Method)(Params ...)>
void Bind()
{
if (_lambda)
LambdaDtor();
_callee = nullptr;
_function = &StaticMethodStub<Method>;
_lambda = nullptr;
}
/// <summary>
@@ -88,8 +134,11 @@ public:
template<class T, ReturnType(T::*Method)(Params ...)>
void Bind(T* callee)
{
if (_lambda)
LambdaDtor();
_callee = callee;
_function = &ClassMethodStub<T, Method>;
_lambda = nullptr;
}
/// <summary>
@@ -98,8 +147,28 @@ public:
/// <param name="method">The method.</param>
void Bind(Signature method)
{
if (_lambda)
LambdaDtor();
_callee = (void*)method;
_function = &StaticPointerMethodStub;
_lambda = nullptr;
}
/// <summary>
/// Binds a lambda.
/// </summary>
/// <param name="lambda">The lambda.</param>
template<typename T>
void Bind(const T& lambda)
{
if (_lambda)
LambdaDtor();
_lambda = (Lambda*)Allocator::Allocate(sizeof(Lambda) + sizeof(T));
_lambda->Refs = 1;
_lambda->Dtor = [](void* callee) -> void { static_cast<T*>(callee)->~T(); };
_function = [](void* callee, Params ... params) -> ReturnType { return (*static_cast<T*>(callee))(Forward<Params>(params)...); };
_callee = (byte*)_lambda + sizeof(Lambda);
new(_callee) T(lambda);
}
/// <summary>
@@ -107,8 +176,11 @@ public:
/// </summary>
void Unbind()
{
if (_lambda)
LambdaDtor();
_callee = nullptr;
_function = nullptr;
_lambda = nullptr;
}
public:
@@ -116,7 +188,6 @@ public:
/// <summary>
/// Returns true if any function has been binded.
/// </summary>
/// <returns>True if any function has been binded, otherwise false.</returns>
FORCE_INLINE bool IsBinded() const
{
return _function != nullptr;
@@ -144,6 +215,31 @@ public:
return _function(_callee, Forward<Params>(params)...);
}
Function& operator=(const Function& other)
{
if (this == &other)
return *this;
_callee = other._callee;
_function = other._function;
_lambda = other._lambda;
if (_lambda)
LambdaCtor();
return *this;
}
Function& operator=(Function&& other) noexcept
{
if (this == &other)
return *this;
_callee = other._callee;
_function = other._function;
_lambda = other._lambda;
other._callee = nullptr;
other._function = nullptr;
other._lambda = nullptr;
return *this;
}
FORCE_INLINE bool operator==(const Function& other) const
{
return _function == other._function && _callee == other._callee;
@@ -189,7 +285,17 @@ public:
~Delegate()
{
Allocator::Free((void*)_ptr);
auto ptr = (FunctionType*)_ptr;
if (ptr)
{
while (_size--)
{
if (ptr->_lambda)
ptr->LambdaDtor();
++ptr;
}
Allocator::Free((void*)_ptr);
}
}
public:
@@ -212,7 +318,6 @@ public:
template<class T, void(T::*Method)(Params ...)>
void Bind(T* callee)
{
ASSERT(callee);
FunctionType f;
f.template Bind<T, Method>(callee);
Bind(f);
@@ -228,6 +333,18 @@ public:
Bind(f);
}
/// <summary>
/// Binds a lambda.
/// </summary>
/// <param name="lambda">The lambda.</param>
template<typename T>
void Bind(const T& lambda)
{
FunctionType f;
f.template Bind<T>(lambda);
Bind(f);
}
/// <summary>
/// Binds a function.
/// </summary>
@@ -243,7 +360,10 @@ public:
{
if (Platform::InterlockedCompareExchange((intptr volatile*)&bindings[i]._function, (intptr)f._function, 0) == 0)
{
Platform::AtomicStore((intptr volatile*)&bindings[i]._callee, (intptr)f._callee);
bindings[i]._callee = f._callee;
bindings[i]._lambda = f._lambda;
if (f._lambda)
f.LambdaCtor();
return;
}
}
@@ -293,7 +413,6 @@ public:
template<class T, void(T::*Method)(Params ...)>
void Unbind(T* callee)
{
ASSERT(callee);
FunctionType f;
f.template Bind<T, Method>(callee);
Unbind(f);
@@ -322,6 +441,11 @@ public:
{
if (Platform::AtomicRead((intptr volatile*)&bindings[i]._callee) == (intptr)f._callee && Platform::AtomicRead((intptr volatile*)&bindings[i]._function) == (intptr)f._function)
{
if (bindings[i]._lambda)
{
bindings[i].LambdaDtor();
bindings[i]._lambda = nullptr;
}
Platform::AtomicStore((intptr volatile*)&bindings[i]._callee, 0);
Platform::AtomicStore((intptr volatile*)&bindings[i]._function, 0);
break;
@@ -343,6 +467,11 @@ public:
FunctionType* bindings = (FunctionType*)Platform::AtomicRead(&_ptr);
for (intptr i = 0; i < size; i++)
{
if (bindings[i]._lambda)
{
bindings[i].LambdaDtor();
bindings[i]._lambda = nullptr;
}
Platform::AtomicStore((intptr volatile*)&bindings[i]._function, 0);
Platform::AtomicStore((intptr volatile*)&bindings[i]._callee, 0);
}
@@ -398,6 +527,9 @@ public:
if (buffer[count]._function != nullptr)
{
buffer[count]._callee = (void*)Platform::AtomicRead((intptr volatile*)&bindings[i]._callee);
buffer[count]._lambda = (typename FunctionType::Lambda*)Platform::AtomicRead((intptr volatile*)&bindings[i]._lambda);
if (buffer[count]._lambda)
buffer[count].LambdaCtor();
count++;
}
}
@@ -414,12 +546,11 @@ public:
FunctionType* bindings = (FunctionType*)Platform::AtomicRead((intptr volatile*)&_ptr);
for (intptr i = 0; i < size; i++)
{
FunctionType f;
f._function = (StubSignature)Platform::AtomicRead((intptr volatile*)&bindings[i]._function);
if (f._function != nullptr)
auto function = (StubSignature)Platform::AtomicRead((intptr volatile*)&bindings[i]._function);
if (function != nullptr)
{
f._callee = (void*)Platform::AtomicRead((intptr volatile*)&bindings[i]._callee);
f._function(f._callee, Forward<Params>(params)...);
auto callee = (void*)Platform::AtomicRead((intptr volatile*)&bindings[i]._callee);
function(callee, Forward<Params>(params)...);
}
}
}