From 9291295a4d9997e3257037db7135d3a5e52d3962 Mon Sep 17 00:00:00 2001 From: Wojtek Figat Date: Sun, 10 Sep 2023 11:25:36 +0200 Subject: [PATCH] Fix `Dictionary` and `HashSet` iterators to prevent unwanted data copies #1361 --- Source/Engine/Content/Assets/VisualScript.cpp | 6 +-- Source/Engine/Core/Collections/Array.h | 8 ++-- Source/Engine/Core/Collections/ChunkedArray.h | 1 - Source/Engine/Core/Collections/Dictionary.h | 48 +++++++++++-------- Source/Engine/Core/Collections/HashSet.h | 46 +++++++++--------- 5 files changed, 57 insertions(+), 52 deletions(-) diff --git a/Source/Engine/Content/Assets/VisualScript.cpp b/Source/Engine/Content/Assets/VisualScript.cpp index e292a0133..698f1ca8c 100644 --- a/Source/Engine/Content/Assets/VisualScript.cpp +++ b/Source/Engine/Content/Assets/VisualScript.cpp @@ -1251,7 +1251,7 @@ void VisualScriptExecutor::ProcessGroupFlow(Box* boxBase, Node* node, Value& val boxBase = node->GetBox(3); if (boxBase->HasConnection()) eatBox(node, boxBase->FirstConnection()); - Dictionary::Iterator it(dictionary, iteratorValue.Value.AsInt); + Dictionary::Iterator it(&dictionary, iteratorValue.Value.AsInt); ++it; iteratorValue.Value.AsInt = it.Index(); } @@ -1269,12 +1269,12 @@ void VisualScriptExecutor::ProcessGroupFlow(Box* boxBase, Node* node, Value& val // Key case 1: if (iteratorIndex != scope->ReturnedValues.Count() && dictionaryIndex != scope->ReturnedValues.Count()) - value = Dictionary::Iterator(*scope->ReturnedValues[dictionaryIndex].Value.AsDictionary, scope->ReturnedValues[iteratorIndex].Value.AsInt)->Key; + value = Dictionary::Iterator(scope->ReturnedValues[dictionaryIndex].Value.AsDictionary, scope->ReturnedValues[iteratorIndex].Value.AsInt)->Key; break; // Value case 2: if (iteratorIndex != scope->ReturnedValues.Count() && dictionaryIndex != scope->ReturnedValues.Count()) - value = Dictionary::Iterator(*scope->ReturnedValues[dictionaryIndex].Value.AsDictionary, scope->ReturnedValues[iteratorIndex].Value.AsInt)->Value; + value = Dictionary::Iterator(scope->ReturnedValues[dictionaryIndex].Value.AsDictionary, scope->ReturnedValues[iteratorIndex].Value.AsInt)->Value; break; // Break case 5: diff --git a/Source/Engine/Core/Collections/Array.h b/Source/Engine/Core/Collections/Array.h index 3feae4e73..58117cf0a 100644 --- a/Source/Engine/Core/Collections/Array.h +++ b/Source/Engine/Core/Collections/Array.h @@ -938,12 +938,12 @@ public: FORCE_INLINE bool IsEnd() const { - return _index == _array->Count(); + return _index == _array->_count; } FORCE_INLINE bool IsNotEnd() const { - return _index != _array->Count(); + return _index != _array->_count; } FORCE_INLINE T& operator*() const @@ -975,7 +975,7 @@ public: Iterator& operator++() { - if (_index != _array->Count()) + if (_index != _array->_count) _index++; return *this; } @@ -983,7 +983,7 @@ public: Iterator operator++(int) { Iterator temp = *this; - if (_index != _array->Count()) + if (_index != _array->_count) _index++; return temp; } diff --git a/Source/Engine/Core/Collections/ChunkedArray.h b/Source/Engine/Core/Collections/ChunkedArray.h index 86473763b..d01711e38 100644 --- a/Source/Engine/Core/Collections/ChunkedArray.h +++ b/Source/Engine/Core/Collections/ChunkedArray.h @@ -95,7 +95,6 @@ public: struct Iterator { friend ChunkedArray; - private: ChunkedArray* _collection; int32 _chunkIndex; diff --git a/Source/Engine/Core/Collections/Dictionary.h b/Source/Engine/Core/Collections/Dictionary.h index 2327f3c24..01c25a842 100644 --- a/Source/Engine/Core/Collections/Dictionary.h +++ b/Source/Engine/Core/Collections/Dictionary.h @@ -237,22 +237,28 @@ public: { friend Dictionary; private: - Dictionary& _collection; + Dictionary* _collection; int32 _index; public: - Iterator(Dictionary& collection, const int32 index) + Iterator(Dictionary* collection, const int32 index) : _collection(collection) , _index(index) { } - Iterator(Dictionary const& collection, const int32 index) - : _collection((Dictionary&)collection) + Iterator(Dictionary const* collection, const int32 index) + : _collection(const_cast(collection)) , _index(index) { } + Iterator() + : _collection(nullptr) + , _index(-1) + { + } + Iterator(const Iterator& i) : _collection(i._collection) , _index(i._index) @@ -273,27 +279,27 @@ public: FORCE_INLINE bool IsEnd() const { - return _index == _collection._size; + return _index == _collection->_size; } FORCE_INLINE bool IsNotEnd() const { - return _index != _collection._size; + return _index != _collection->_size; } FORCE_INLINE Bucket& operator*() const { - return _collection._allocation.Get()[_index]; + return _collection->_allocation.Get()[_index]; } FORCE_INLINE Bucket* operator->() const { - return &_collection._allocation.Get()[_index]; + return &_collection->_allocation.Get()[_index]; } FORCE_INLINE explicit operator bool() const { - return _index >= 0 && _index < _collection._size; + return _index >= 0 && _index < _collection->_size; } FORCE_INLINE bool operator!() const @@ -320,10 +326,10 @@ public: Iterator& operator++() { - const int32 capacity = _collection.Capacity(); + const int32 capacity = _collection->_size; if (_index != capacity) { - const Bucket* data = _collection._allocation.Get(); + const Bucket* data = _collection->_allocation.Get(); do { _index++; @@ -343,7 +349,7 @@ public: { if (_index > 0) { - const Bucket* data = _collection._allocation.Get(); + const Bucket* data = _collection->_allocation.Get(); do { _index--; @@ -633,7 +639,7 @@ public: /// Iterator with key and value. void Add(const Iterator& i) { - ASSERT(&i._collection != this && i); + ASSERT(i._collection != this && i); const Bucket& bucket = *i; Add(bucket.Key, bucket.Value); } @@ -667,7 +673,7 @@ public: /// True if cannot remove item from the collection because cannot find it, otherwise false. bool Remove(const Iterator& i) { - ASSERT(&i._collection == this); + ASSERT(i._collection == this); if (i) { ASSERT(_allocation.Get()[i._index].IsOccupied()); @@ -711,7 +717,7 @@ public: return End(); FindPositionResult pos; FindPosition(key, pos); - return pos.ObjectIndex != -1 ? Iterator(*this, pos.ObjectIndex) : End(); + return pos.ObjectIndex != -1 ? Iterator(this, pos.ObjectIndex) : End(); } /// @@ -812,38 +818,38 @@ public: public: Iterator Begin() const { - Iterator i(*this, -1); + Iterator i(this, -1); ++i; return i; } Iterator End() const { - return Iterator(*this, _size); + return Iterator(this, _size); } Iterator begin() { - Iterator i(*this, -1); + Iterator i(this, -1); ++i; return i; } FORCE_INLINE Iterator end() { - return Iterator(*this, _size); + return Iterator(this, _size); } const Iterator begin() const { - Iterator i(*this, -1); + Iterator i(this, -1); ++i; return i; } FORCE_INLINE const Iterator end() const { - return Iterator(*this, _size); + return Iterator(this, _size); } protected: diff --git a/Source/Engine/Core/Collections/HashSet.h b/Source/Engine/Core/Collections/HashSet.h index ad6f8ffc6..107e42e65 100644 --- a/Source/Engine/Core/Collections/HashSet.h +++ b/Source/Engine/Core/Collections/HashSet.h @@ -213,17 +213,17 @@ public: { friend HashSet; private: - HashSet& _collection; + HashSet* _collection; int32 _index; - Iterator(HashSet& collection, const int32 index) + Iterator(HashSet* collection, const int32 index) : _collection(collection) , _index(index) { } - Iterator(HashSet const& collection, const int32 index) - : _collection((HashSet&)collection) + Iterator(HashSet const* collection, const int32 index) + : _collection(const_cast(collection)) , _index(index) { } @@ -244,27 +244,27 @@ public: public: FORCE_INLINE bool IsEnd() const { - return _index == _collection.Capacity(); + return _index == _collection->_size; } FORCE_INLINE bool IsNotEnd() const { - return _index != _collection.Capacity(); + return _index != _collection->_size; } FORCE_INLINE Bucket& operator*() const { - return _collection._allocation.Get()[_index]; + return _collection->_allocation.Get()[_index]; } FORCE_INLINE Bucket* operator->() const { - return &_collection._allocation.Get()[_index]; + return &_collection->_allocation.Get()[_index]; } FORCE_INLINE explicit operator bool() const { - return _index >= 0 && _index < _collection._size; + return _index >= 0 && _index < _collection->_size; } FORCE_INLINE bool operator !() const @@ -274,12 +274,12 @@ public: FORCE_INLINE bool operator==(const Iterator& v) const { - return _index == v._index && &_collection == &v._collection; + return _index == v._index && _collection == v._collection; } FORCE_INLINE bool operator!=(const Iterator& v) const { - return _index != v._index || &_collection != &v._collection; + return _index != v._index || _collection != v._collection; } Iterator& operator=(const Iterator& v) @@ -291,10 +291,10 @@ public: Iterator& operator++() { - const int32 capacity = _collection.Capacity(); + const int32 capacity = _collection->_size; if (_index != capacity) { - const Bucket* data = _collection._allocation.Get(); + const Bucket* data = _collection->_allocation.Get(); do { _index++; @@ -314,7 +314,7 @@ public: { if (_index > 0) { - const Bucket* data = _collection._allocation.Get(); + const Bucket* data = _collection->_allocation.Get(); do { _index--; @@ -464,7 +464,7 @@ public: /// Iterator with item to add void Add(const Iterator& i) { - ASSERT(&i._collection != this && i); + ASSERT(i._collection != this && i); const Bucket& bucket = *i; Add(bucket.Item); } @@ -498,7 +498,7 @@ public: /// True if cannot remove item from the collection because cannot find it, otherwise false. bool Remove(const Iterator& i) { - ASSERT(&i._collection == this); + ASSERT(i._collection == this); if (i) { ASSERT(_allocation.Get()[i._index].IsOccupied()); @@ -523,7 +523,7 @@ public: return End(); FindPositionResult pos; FindPosition(item, pos); - return pos.ObjectIndex != -1 ? Iterator(*this, pos.ObjectIndex) : End(); + return pos.ObjectIndex != -1 ? Iterator(this, pos.ObjectIndex) : End(); } /// @@ -559,38 +559,38 @@ public: public: Iterator Begin() const { - Iterator i(*this, -1); + Iterator i(this, -1); ++i; return i; } Iterator End() const { - return Iterator(*this, _size); + return Iterator(this, _size); } Iterator begin() { - Iterator i(*this, -1); + Iterator i(this, -1); ++i; return i; } FORCE_INLINE Iterator end() { - return Iterator(*this, _size); + return Iterator(this, _size); } const Iterator begin() const { - Iterator i(*this, -1); + Iterator i(this, -1); ++i; return i; } FORCE_INLINE const Iterator end() const { - return Iterator(*this, _size); + return Iterator(this, _size); } protected: