// Copyright (c) 2012-2023 Wojciech Figat. All rights reserved.

#pragma once

#if PLATFORM_WIN32

#include "Engine/Core/Templates.h"
#include "../Win32/IncludeWindowsHeaders.h"

/// <summary>
/// View type returned by ComPtr::operator-> in debug builds. It re-declares the three IUnknown methods so that they shadow the ones inherited from T.
/// </summary>
/// <remarks>
/// NOTE(review): this helper was meant to make the IUnknown methods private, so that reference counting only happens through ComPtr.
/// However, the declarations below sit in a public section, so callers can still reach them through operator->.
/// Confirm whether hiding them is still desired before changing the access level; existing debug-build callers may rely on it.
/// </remarks>
template<typename T>
class RemoveIUnknownBase : public T
{
public:
    ~RemoveIUnknownBase()
    {
    }

    // STDMETHOD macro implies virtual.
    // ComPtr can be used with any class that implements the 3 methods of IUnknown.
    // ComPtr does not require these methods to be virtual.
    // When ComPtr is used with a class without a virtual table, marking the functions
    // as virtual in this class adds unnecessary overhead.
    HRESULT __stdcall QueryInterface(REFIID riid, _COM_Outptr_ void** ppvObject);
    ULONG __stdcall AddRef();
    ULONG __stdcall Release();

    // Re-exposes the typed QueryInterface helper.
    // The QueryInterface(REFIID, void**) declaration above hides every QueryInterface overload inherited from T.
    template<class Q>
    HRESULT __stdcall QueryInterface(_COM_Outptr_ Q** pp)
    {
        return QueryInterface(__uuidof(Q), reinterpret_cast<void**>(pp));
    }
};

/// <summary>
/// Maps an interface type to the RemoveIUnknownBase view used by ComPtr::operator-> in debug builds.
/// </summary>
template<typename T>
struct RemoveIUnknown
{
    typedef RemoveIUnknownBase<T> ReturnType;
};

/// <summary>
/// Const-preserving specialization: a const interface maps to a const view.
/// </summary>
template<typename T>
struct RemoveIUnknown<const T>
{
    typedef const RemoveIUnknownBase<T> ReturnType;
};

/// <summary>
/// Smart pointer for COM-style interfaces.
/// Holds at most one reference: it calls AddRef when a pointer is copied in and Release when the pointer is dropped or replaced.
/// </summary>
/// <typeparam name="T">The interface type. It must provide AddRef, Release and QueryInterface.</typeparam>
template<typename T>
class ComPtr
{
public:
    typedef T InterfaceType;

protected:
    InterfaceType* ptr_;

    template<class U>
    friend class ComPtr;

    // Adds a reference to the held interface (no-op when empty).
    void InternalAddRef() const
    {
        if (ptr_ != nullptr)
        {
            ptr_->AddRef();
        }
    }

    // Empties this pointer and releases the previously held interface.
    // The member is nulled before Release is called, presumably so that code re-entered from a final Release never observes a dangling pointer.
    // Returns the value returned by Release, or 0 when the pointer was already empty.
    unsigned long InternalRelease()
    {
        unsigned long ref = 0;
        T* temp = ptr_;
        if (temp != nullptr)
        {
            ptr_ = nullptr;
            ref = temp->Release();
        }
        return ref;
    }

public:
    ComPtr() noexcept
        : ptr_(nullptr)
    {
    }

    ComPtr(decltype(__nullptr)) noexcept
        : ptr_(nullptr)
    {
    }

    // Wraps a raw pointer and takes a new reference to it (AddRef).
    template<class U>
    ComPtr(U* other)
        : ptr_(other)
    {
        InternalAddRef();
    }

    ComPtr(const ComPtr& other)
        : ptr_(other.ptr_)
    {
        InternalAddRef();
    }

    // Copy constructor that allows to instantiate class when U* is convertible to T*
    template<class U>
    ComPtr(const ComPtr<U>& other, typename TEnableIf<__is_convertible_to(U*, T*), void*>::Type* = 0)
        : ptr_(other.ptr_)
    {
        InternalAddRef();
    }

    // Move constructor: steals the reference and leaves 'other' empty.
    ComPtr(ComPtr&& other) noexcept
        : ptr_(nullptr)
    {
        // Take the address through unsigned char&, because ComPtr overloads operator& (it returns T**).
        if (this != reinterpret_cast<ComPtr*>(&reinterpret_cast<unsigned char&>(other)))
        {
            Swap(other);
        }
    }

    // Move constructor that allows instantiation of a class when U* is convertible to T*
    template<class U>
    ComPtr(ComPtr<U>&& other, typename TEnableIf<__is_convertible_to(U*, T*), void*>::Type* = 0) noexcept
        : ptr_(other.ptr_)
    {
        other.ptr_ = nullptr;
    }

    ~ComPtr()
    {
        InternalRelease();
    }

public:
    // Releases the held interface and leaves this pointer empty.
    ComPtr& operator=(decltype(__nullptr))
    {
        InternalRelease();
        return *this;
    }

    // Replaces the held interface with 'other' (AddRef on the new one, Release on the old one).
    ComPtr& operator=(T* other)
    {
        if (ptr_ != other)
        {
            ComPtr(other).Swap(*this);
        }
        return *this;
    }

    template<typename U>
    ComPtr& operator=(U* other)
    {
        ComPtr(other).Swap(*this);
        return *this;
    }

    ComPtr& operator=(const ComPtr& other)
    {
        if (ptr_ != other.ptr_)
        {
            ComPtr(other).Swap(*this);
        }
        return *this;
    }

    template<class U>
    ComPtr& operator=(const ComPtr<U>& other)
    {
        ComPtr(other).Swap(*this);
        return *this;
    }

    ComPtr& operator=(ComPtr&& other) noexcept
    {
        ComPtr(static_cast<ComPtr&&>(other)).Swap(*this);
        return *this;
    }

    template<class U>
    ComPtr& operator=(ComPtr<U>&& other) noexcept
    {
        ComPtr(static_cast<ComPtr<U>&&>(other)).Swap(*this);
        return *this;
    }

public:
    // Exchanges the held pointers. No reference counts are touched.
    void Swap(ComPtr&& r) noexcept
    {
        T* tmp = ptr_;
        ptr_ = r.ptr_;
        r.ptr_ = tmp;
    }

    // Exchanges the held pointers. No reference counts are touched.
    void Swap(ComPtr& r) noexcept
    {
        T* tmp = ptr_;
        ptr_ = r.ptr_;
        r.ptr_ = tmp;
    }

    operator bool() const
    {
        return Get() != nullptr;
    }

    operator T*() const
    {
        return ptr_;
    }

    // Returns the raw pointer without changing its reference count.
    T* Get() const
    {
        return ptr_;
    }

#if BUILD_DEBUG
    // In debug builds, member access goes through the RemoveIUnknownBase view of the interface.
    typename RemoveIUnknown<InterfaceType>::ReturnType* operator->() const
    {
        return static_cast<typename RemoveIUnknown<InterfaceType>::ReturnType*>(ptr_);
    }
#else
    InterfaceType* operator->() const
    {
        return ptr_;
    }
#endif

    // Returns the address of the raw pointer WITHOUT releasing the current interface.
    // If an API writes a new interface through this address, the old reference is never released.
    // Use ReleaseAndGetAddressOf in that case.
    T** operator&()
    {
        return &ptr_;
    }

    // Const overload: the held pointer cannot be reseated through the returned address.
    T* const* operator&() const
    {
        return &ptr_;
    }

    T* const* GetAddressOf() const
    {
        return &ptr_;
    }

    T** GetAddressOf()
    {
        return &ptr_;
    }

    // Releases the current interface, then returns the address of the (now null) raw pointer, ready to be filled by an out-parameter API.
    T** ReleaseAndGetAddressOf()
    {
        InternalRelease();
        return &ptr_;
    }

    // Gives up ownership: returns the raw pointer without calling Release and leaves this pointer empty.
    T* Detach()
    {
        T* ptr = ptr_;
        ptr_ = nullptr;
        return ptr;
    }

    // Releases the current interface and takes ownership of 'other' without calling AddRef on it.
    // Attaching the pointer that is already held releases it first, which may destroy the object if this was its last reference.
    void Attach(InterfaceType* other)
    {
        InternalRelease();
        ptr_ = other;
    }

    // Releases the held interface. Returns the value returned by Release, or 0 when already empty.
    unsigned long Reset()
    {
        return InternalRelease();
    }

    // Queries the held interface for U and stores the result in *p.
    // Any interface previously held by *p is released.
    // Returns the HRESULT from QueryInterface, or E_POINTER when this pointer is empty (in that case *p is left empty).
    // The query writes into a temporary, which stays correct even when p refers to this object.
    template<typename U>
    HRESULT As(ComPtr<U>* p) const
    {
        ComPtr<U> result;
        HRESULT hr = E_POINTER;
        if (ptr_ != nullptr)
        {
            hr = ptr_->QueryInterface(__uuidof(U), reinterpret_cast<void**>(result.GetAddressOf()));
        }
        p->Swap(result);
        return hr;
    }
};

// Comparison operators compare the raw interface pointers, not COM object identity.
// Two different interface pointers to the same object compare unequal.
template<class T, class U>
bool operator==(const ComPtr<T>& a, const ComPtr<U>& b)
{
    static_assert(__is_base_of(T, U) || __is_base_of(U, T), "'T' and 'U' pointers must be comparable");
    return a.Get() == b.Get();
}

template<class T>
bool operator==(const ComPtr<T>& a, decltype(__nullptr))
{
    return a.Get() == nullptr;
}

template<class T>
bool operator==(decltype(__nullptr), const ComPtr<T>& a)
{
    return a.Get() == nullptr;
}

template<class T, class U>
bool operator!=(const ComPtr<T>& a, const ComPtr<U>& b)
{
    static_assert(__is_base_of(T, U) || __is_base_of(U, T), "'T' and 'U' pointers must be comparable");
    return a.Get() != b.Get();
}

template<class T>
bool operator!=(const ComPtr<T>& a, decltype(__nullptr))
{
    return a.Get() != nullptr;
}

template<class T>
bool operator!=(decltype(__nullptr), const ComPtr<T>& a)
{
    return a.Get() != nullptr;
}

template<class T, class U>
bool operator<(const ComPtr<T>& a, const ComPtr<U>& b)
{
    static_assert(__is_base_of(T, U) || __is_base_of(U, T), "'T' and 'U' pointers must be comparable");
    return a.Get() < b.Get();
}

#endif