Fix ThreadLocal to use atomic operations and prevent rare race conditions
This commit is contained in:
@@ -10,19 +10,18 @@
|
||||
#define THREAD_LOCAL_MAX_CAPACITY 16
|
||||
|
||||
/// <summary>
|
||||
/// Per thread local variable
|
||||
/// Per-thread local variable storage.
|
||||
/// Implemented using atomic with per-thread storage indexed via thread id hashing.
|
||||
/// ForConsider using 'THREADLOCAL' define before the variable instead.
|
||||
/// </summary>
|
||||
template<typename T, int32 MaxThreads = THREAD_LOCAL_MAX_CAPACITY, bool ClearMemory = true>
|
||||
class ThreadLocal
|
||||
{
|
||||
// Note: this is kind of weak-implementation. We don't want to use locks/semaphores.
|
||||
// For better performance use 'THREADLOCAL' define before the variable
|
||||
|
||||
protected:
|
||||
|
||||
struct Bucket
|
||||
{
|
||||
uint64 ThreadID;
|
||||
volatile int64 ThreadID;
|
||||
T Value;
|
||||
};
|
||||
|
||||
@@ -56,14 +55,12 @@ public:
|
||||
_buckets[GetIndex()].Value = value;
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
int32 Count() const
|
||||
{
|
||||
int32 result = 0;
|
||||
for (int32 i = 0; i < MaxThreads; i++)
|
||||
{
|
||||
if (_buckets[i].ThreadID != 0)
|
||||
if (Platform::AtomicRead((int64 volatile*)&_buckets[i].ThreadID) != 0)
|
||||
result++;
|
||||
}
|
||||
return result;
|
||||
@@ -81,23 +78,25 @@ public:
|
||||
|
||||
protected:
|
||||
|
||||
FORCE_INLINE static int32 Hash(const uint64 value)
|
||||
FORCE_INLINE static int32 Hash(const int64 value)
|
||||
{
|
||||
return value & (MaxThreads - 1);
|
||||
}
|
||||
|
||||
FORCE_INLINE int32 GetIndex()
|
||||
{
|
||||
// TODO: fix it because now we can use only (MaxThreads-1) buckets
|
||||
ASSERT(Count() < MaxThreads);
|
||||
|
||||
auto key = Platform::GetCurrentThreadID();
|
||||
int64 key = (int64)Platform::GetCurrentThreadID();
|
||||
auto index = Hash(key);
|
||||
|
||||
while (_buckets[index].ThreadID != key && _buckets[index].ThreadID != 0)
|
||||
while (true)
|
||||
{
|
||||
const int64 value = Platform::AtomicRead(&_buckets[index].ThreadID);
|
||||
if (value == key)
|
||||
break;
|
||||
if (value == 0 && Platform::InterlockedCompareExchange(&_buckets[index].ThreadID, key, 0) == 0)
|
||||
break;
|
||||
index = Hash(index + 1);
|
||||
_buckets[index].ThreadID = key;
|
||||
|
||||
}
|
||||
return index;
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user