diff --git a/Source/Engine/Networking/NetworkInternal.h b/Source/Engine/Networking/NetworkInternal.h index f65c4a5e9..d063a0ca3 100644 --- a/Source/Engine/Networking/NetworkInternal.h +++ b/Source/Engine/Networking/NetworkInternal.h @@ -13,6 +13,7 @@ enum class NetworkMessageIDs : uint8 ObjectSpawn, ObjectDespawn, ObjectRole, + ObjectRpc, MAX, }; @@ -29,4 +30,5 @@ public: static void OnNetworkMessageObjectSpawn(NetworkEvent& event, NetworkClient* client, NetworkPeer* peer); static void OnNetworkMessageObjectDespawn(NetworkEvent& event, NetworkClient* client, NetworkPeer* peer); static void OnNetworkMessageObjectRole(NetworkEvent& event, NetworkClient* client, NetworkPeer* peer); + static void OnNetworkMessageObjectRpc(NetworkEvent& event, NetworkClient* client, NetworkPeer* peer); }; diff --git a/Source/Engine/Networking/NetworkManager.cpp b/Source/Engine/Networking/NetworkManager.cpp index 04d4eb860..72634230a 100644 --- a/Source/Engine/Networking/NetworkManager.cpp +++ b/Source/Engine/Networking/NetworkManager.cpp @@ -132,6 +132,7 @@ namespace NetworkInternal::OnNetworkMessageObjectSpawn, NetworkInternal::OnNetworkMessageObjectDespawn, NetworkInternal::OnNetworkMessageObjectRole, + NetworkInternal::OnNetworkMessageObjectRpc, }; } diff --git a/Source/Engine/Networking/NetworkReplicator.cpp b/Source/Engine/Networking/NetworkReplicator.cpp index 214b26687..ac0d0c72e 100644 --- a/Source/Engine/Networking/NetworkReplicator.cpp +++ b/Source/Engine/Networking/NetworkReplicator.cpp @@ -9,6 +9,7 @@ #include "NetworkPeer.h" #include "NetworkChannelType.h" #include "NetworkEvent.h" +#include "NetworkRpc.h" #include "INetworkSerializable.h" #include "INetworkObject.h" #include "Engine/Core/Log.h" @@ -71,6 +72,15 @@ PACK_STRUCT(struct NetworkMessageObjectRole uint32 OwnerClientId; }); +PACK_STRUCT(struct NetworkMessageObjectRpc + { + NetworkMessageIDs ID = NetworkMessageIDs::ObjectRpc; + Guid ObjectId; + char RpcTypeName[128]; // TODO: introduce networked-name to synchronize unique names as ushort (less data over network) + char RpcName[128]; // TODO: introduce networked-name to synchronize unique names as ushort (less data over network) + uint16 ArgsSize; + }); + struct NetworkReplicatedObject { ScriptingObjectReference Object; @@ -128,18 +138,30 @@ struct SpawnItem NetworkObjectRole Role; }; +struct RpcItem +{ + ScriptingObjectReference Object; + NetworkRpcName Name; + NetworkRpcInfo Info; + BytesContainer ArgsData; +}; + namespace { CriticalSection ObjectsLock; HashSet Objects; Array SpawnQueue; Array DespawnQueue; + Array RpcQueue; Dictionary IdsRemappingTable; NetworkStream* CachedWriteStream = nullptr; NetworkStream* CachedReadStream = nullptr; Array NewClients; Array CachedTargets; Dictionary SerializersTable; +#if !COMPILE_WITHOUT_CSHARP + Dictionary CSharpCachedNames; +#endif } class NetworkReplicationService : public EngineService @@ -156,6 +178,9 @@ public: void NetworkReplicationService::Dispose() { NetworkInternal::NetworkReplicatorClear(); +#if !COMPILE_WITHOUT_CSHARP + CSharpCachedNames.ClearDelete(); +#endif } NetworkReplicationService NetworkReplicationServiceInstance; @@ -270,6 +295,12 @@ FORCE_INLINE void BuildCachedTargets(const NetworkReplicatedObject& item) BuildCachedTargets(NetworkManager::Clients, item.TargetClientIds, item.OwnerClientId); } +FORCE_INLINE void GetNetworkName(char buffer[128], const StringAnsiView& name) +{ + Platform::MemoryCopy(buffer, name.Get(), name.Length()); + buffer[name.Length()] = 0; +} + void SendObjectSpawnMessage(const NetworkReplicatedObject& item, ScriptingObject* obj) { NetworkMessageObjectSpawn msgData; @@ -291,9 +322,7 @@ void SendObjectSpawnMessage(const NetworkReplicatedObject& item, ScriptingObject msgData.PrefabObjectID = objScene->GetPrefabObjectID(); } msgData.OwnerClientId = item.OwnerClientId; - const StringAnsiView& objectTypeName = obj->GetType().Fullname; - Platform::MemoryCopy(msgData.ObjectTypeName, objectTypeName.Get(), objectTypeName.Length()); - msgData.ObjectTypeName[objectTypeName.Length()] = 0; + GetNetworkName(msgData.ObjectTypeName, obj->GetType().Fullname); auto* peer = NetworkManager::Peer; NetworkMessage msg = peer->BeginSendMessage(); msg.WriteStructure(msgData); @@ -368,6 +397,48 @@ void NetworkReplicator::AddSerializer(const ScriptingTypeHandle& typeHandle, con AddSerializer(typeHandle, INetworkSerializable_Managed, INetworkSerializable_Managed, (void*)*(SerializeFunc*)&serialize, (void*)*(SerializeFunc*)&deserialize); } +void RPC_Execute_Managed(ScriptingObject* obj, NetworkStream* stream, void* tag) +{ + auto signature = (Function::Signature)tag; + signature(obj, stream); +} + +void NetworkReplicator::AddRPC(const ScriptingTypeHandle& typeHandle, const StringAnsiView& name, const Function& execute, bool isServer, bool isClient, NetworkChannelType channel) +{ + if (!typeHandle) + return; + + const NetworkRpcName rpcName(typeHandle, GetCSharpCachedName(name)); + + NetworkRpcInfo rpcInfo; + rpcInfo.Server = isServer; + rpcInfo.Client = isClient; + rpcInfo.Channel = (uint8)channel; + rpcInfo.Invoke = nullptr; // C# RPCs invoking happens on C# side (build-time code generation) + rpcInfo.Execute = RPC_Execute_Managed; + rpcInfo.Tag = (void*)*(SerializeFunc*)&execute; + + // Add to the global RPCs table + NetworkRpcInfo::RPCsTable[rpcName] = rpcInfo; +} + +void NetworkReplicator::CSharpEndInvokeRPC(ScriptingObject* obj, const ScriptingTypeHandle& type, const StringAnsiView& name, NetworkStream* argsStream) +{ + EndInvokeRPC(obj, type, GetCSharpCachedName(name), argsStream); +} + +StringAnsiView NetworkReplicator::GetCSharpCachedName(const StringAnsiView& name) +{ + // Cache method name on a heap to support C# hot-reloads (also glue code from C# passes view to the stack-only text so cache it here) + StringAnsi* result; + if (!CSharpCachedNames.TryGet(name, result)) + { + result = New(name); + CSharpCachedNames.Add(StringAnsiView(*result), result); + } + return StringAnsiView(*result); +} + #endif void NetworkReplicator::AddSerializer(const ScriptingTypeHandle& typeHandle, SerializeFunc serialize, SerializeFunc deserialize, void* serializeTag, void* deserializeTag) @@ -610,6 +681,32 @@ void NetworkReplicator::DirtyObject(ScriptingObject* obj) // TODO: implement objects state replication frequency and dirtying } +Dictionary NetworkRpcInfo::RPCsTable; + +NetworkStream* NetworkReplicator::BeginInvokeRPC() +{ + if (CachedWriteStream == nullptr) + CachedWriteStream = New(); + CachedWriteStream->Initialize(); + return CachedWriteStream; +} + +void NetworkReplicator::EndInvokeRPC(ScriptingObject* obj, const ScriptingTypeHandle& type, const StringAnsiView& name, NetworkStream* argsStream) +{ + const NetworkRpcInfo* info = NetworkRpcInfo::RPCsTable.TryGet(NetworkRpcName(type, name)); + if (!info || !obj) + return; + ObjectsLock.Lock(); + auto& rpc = RpcQueue.AddOne(); + rpc.Object = obj; + rpc.Name.First = type; + rpc.Name.Second = name; + rpc.Info = *info; + const Span argsData(argsStream->GetBuffer(), argsStream->GetPosition()); + rpc.ArgsData.Copy(argsData); + ObjectsLock.Unlock(); +} + void NetworkInternal::NetworkReplicatorClientConnected(NetworkClient* client) { ScopeLock lock(ObjectsLock); @@ -685,6 +782,7 @@ void NetworkInternal::NetworkReplicatorUpdate() if (CachedWriteStream == nullptr) CachedWriteStream = New(); const bool isClient = NetworkManager::IsClient(); + const bool isServer = NetworkManager::IsServer(); NetworkStream* stream = CachedWriteStream; NetworkPeer* peer = NetworkManager::Peer; @@ -763,6 +861,8 @@ void NetworkInternal::NetworkReplicatorUpdate() for (auto& e : SpawnQueue) { ScriptingObject* obj = e.Object.Get(); + if (!obj) + continue; auto it = Objects.Find(obj->GetID()); if (it == Objects.End()) { @@ -848,9 +948,7 @@ void NetworkInternal::NetworkReplicatorUpdate() IdsRemappingTable.KeyOf(msgData.ObjectId, &msgData.ObjectId); IdsRemappingTable.KeyOf(msgData.ParentId, &msgData.ParentId); } - const StringAnsiView& objectTypeName = obj->GetType().Fullname; - Platform::MemoryCopy(msgData.ObjectTypeName, objectTypeName.Get(), objectTypeName.Length()); - msgData.ObjectTypeName[objectTypeName.Length()] = 0; + GetNetworkName(msgData.ObjectTypeName, obj->GetType().Fullname); msgData.DataSize = size; // TODO: split object data (eg. more messages) if needed NetworkMessage msg = peer->BeginSendMessage(); @@ -869,6 +967,47 @@ void NetworkInternal::NetworkReplicatorUpdate() } } + // Invoke RPCs + for (auto& e : RpcQueue) + { + ScriptingObject* obj = e.Object.Get(); + if (!obj) + continue; + auto it = Objects.Find(obj->GetID()); + if (it == Objects.End()) + continue; + auto& item = it->Item; + + // Send despawn message + //NETWORK_REPLICATOR_LOG(Info, "[NetworkReplicator] Rpc {}::{} object ID={}", e.Name.First.ToString(), String(e.Name.Second), item.ToString()); + NetworkMessageObjectRpc msgData; + msgData.ObjectId = item.ObjectId; + if (isClient) + { + // Remap local client object ids into server ids + IdsRemappingTable.KeyOf(msgData.ObjectId, &msgData.ObjectId); + } + GetNetworkName(msgData.RpcTypeName, e.Name.First.GetType().Fullname); + GetNetworkName(msgData.RpcName, e.Name.Second); + msgData.ArgsSize = (uint16)e.ArgsData.Length(); + NetworkMessage msg = peer->BeginSendMessage(); + msg.WriteStructure(msgData); + msg.WriteBytes(e.ArgsData.Get(), e.ArgsData.Length()); + NetworkChannelType channel = (NetworkChannelType)e.Info.Channel; + if (e.Info.Server && isClient) + { + // Client -> Server + peer->EndSendMessage(channel, msg); + } + else if (e.Info.Client && isServer) + { + // Server -> Client(s) + BuildCachedTargets(item); + peer->EndSendMessage(channel, msg, CachedTargets); + } + } + RpcQueue.Clear(); + // Clear networked objects mapping table Scripting::ObjectsLookupIdMapping.Set(nullptr); } @@ -1019,7 +1158,7 @@ void NetworkInternal::OnNetworkMessageObjectSpawn(NetworkEvent& event, NetworkCl else { // Spawn object - const ScriptingTypeHandle objectType = Scripting::FindScriptingType(StringAnsiView(msgData.ObjectTypeName)); + const ScriptingTypeHandle objectType = Scripting::FindScriptingType(msgData.ObjectTypeName); obj = ScriptingObject::NewObject(objectType); if (!obj) { @@ -1139,3 +1278,54 @@ void NetworkInternal::OnNetworkMessageObjectRole(NetworkEvent& event, NetworkCli NETWORK_REPLICATOR_LOG(Error, "[NetworkReplicator] Unknown object role update {}", msgData.ObjectId); } } + +void NetworkInternal::OnNetworkMessageObjectRpc(NetworkEvent& event, NetworkClient* client, NetworkPeer* peer) +{ + NetworkMessageObjectRpc msgData; + event.Message.ReadStructure(msgData); + ScopeLock lock(ObjectsLock); + NetworkReplicatedObject* e = ResolveObject(msgData.ObjectId); + if (e) + { + auto& item = *e; + ScriptingObject* obj = item.Object.Get(); + if (!obj) + return; + + // Find RPC info + NetworkRpcName name; + name.First = Scripting::FindScriptingType(msgData.RpcTypeName); + name.Second = msgData.RpcName; + const NetworkRpcInfo* info = NetworkRpcInfo::RPCsTable.TryGet(name); + if (!info) + { + NETWORK_REPLICATOR_LOG(Error, "[NetworkReplicator] Unknown object {} RPC {}::{}", msgData.ObjectId, String(msgData.RpcTypeName), String(msgData.RpcName)); + return; + } + + // Validate RPC + if (info->Server && NetworkManager::IsClient()) + { + NETWORK_REPLICATOR_LOG(Error, "[NetworkReplicator] Cannot invoke server RPC {}::{} on client", String(msgData.RpcTypeName), String(msgData.RpcName)); + return; + } + if (info->Client && NetworkManager::IsServer()) + { + NETWORK_REPLICATOR_LOG(Error, "[NetworkReplicator] Cannot invoke client RPC {}::{} on server", String(msgData.RpcTypeName), String(msgData.RpcName)); + return; + } + + // Setup message reading stream + if (CachedReadStream == nullptr) + CachedReadStream = New(); + NetworkStream* stream = CachedReadStream; + stream->Initialize(event.Message.Buffer + event.Message.Position, msgData.ArgsSize); + + // Execute RPC + info->Execute(obj, stream, info->Tag); + } + else + { + NETWORK_REPLICATOR_LOG(Error, "[NetworkReplicator] Unknown object {} RPC {}::{}", msgData.ObjectId, String(msgData.RpcTypeName), String(msgData.RpcName)); + } +} diff --git a/Source/Engine/Networking/NetworkReplicator.cs b/Source/Engine/Networking/NetworkReplicator.cs index 6c746e7b9..b730f9403 100644 --- a/Source/Engine/Networking/NetworkReplicator.cs +++ b/Source/Engine/Networking/NetworkReplicator.cs @@ -30,6 +30,16 @@ namespace FlaxEngine.Networking /// var stream = (NetworkStream)Object.FromUnmanagedPtr(streamPtr) public delegate void SerializeFunc(IntPtr instancePtr, IntPtr streamPtr); + /// + /// Network RPC executing delegate. + /// + /// + /// Use Object.FromUnmanagedPtr(objPtr/streamPtr) to get object or NetworkStream from raw native pointers. + /// + /// var instance = Object.FromUnmanagedPtr(instancePtr) + /// var stream = (NetworkStream)Object.FromUnmanagedPtr(streamPtr) + public delegate void ExecuteRPCFunc(IntPtr instancePtr, IntPtr streamPtr); + /// /// Registers a new serialization methods for a given C# type. /// @@ -93,5 +103,35 @@ namespace FlaxEngine.Networking } return Internal_InvokeSerializer(type, instance, FlaxEngine.Object.GetUnmanagedPtr(stream), serialize); } + + /// + /// Ends invoking the RPC. + /// + /// The target object to invoke RPC. + /// The RPC type. + /// The RPC name. + /// The RPC serialized arguments stream returned from BeginInvokeRPC. + [Unmanaged] + public static void EndInvokeRPC(Object obj, Type type, string name, NetworkStream argsStream) + { + Internal_CSharpEndInvokeRPC(FlaxEngine.Object.GetUnmanagedPtr(obj), type, name, FlaxEngine.Object.GetUnmanagedPtr(argsStream)); + } + + /// + /// Registers a RPC method for a given C# method. + /// + /// The C# type (FlaxEngine.Object). + /// The RPC method name (from that type). + /// Function to call for RPC execution. + /// Server RPC. + /// Client RPC. + /// Network channel to use for RPC transport. + [Unmanaged] + public static void AddRPC(Type type, string name, ExecuteRPCFunc execute, bool isServer = true, bool isClient = false, NetworkChannelType channel = NetworkChannelType.ReliableOrdered) + { + if (!typeof(FlaxEngine.Object).IsAssignableFrom(type)) + throw new ArgumentException("Not supported type for RPC. Only FlaxEngine.Object types are valid."); + Internal_AddRPC(type, name, Marshal.GetFunctionPointerForDelegate(execute), isServer, isClient, channel); + } } } diff --git a/Source/Engine/Networking/NetworkReplicator.h b/Source/Engine/Networking/NetworkReplicator.h index 237d69f84..cc527630e 100644 --- a/Source/Engine/Networking/NetworkReplicator.h +++ b/Source/Engine/Networking/NetworkReplicator.h @@ -154,8 +154,27 @@ public: /// The network object. API_FUNCTION() static void DirtyObject(ScriptingObject* obj); +public: + /// + /// Begins invoking the RPC and returns the Network Stream to serialize parameters to. + /// + /// Network Stream to write RPC parameters to. + API_FUNCTION() static NetworkStream* BeginInvokeRPC(); + + /// + /// Ends invoking the RPC. + /// + /// The target object to invoke RPC. + /// The RPC type. + /// The RPC name. + /// The RPC serialized arguments stream returned from BeginInvokeRPC. + static void EndInvokeRPC(ScriptingObject* obj, const ScriptingTypeHandle& type, const StringAnsiView& name, NetworkStream* argsStream); + private: #if !COMPILE_WITHOUT_CSHARP - API_FUNCTION(NoProxy) static void AddSerializer(const ScriptingTypeHandle& type, const Function& serialize, const Function& deserialize); + API_FUNCTION(NoProxy) static void AddSerializer(const ScriptingTypeHandle& typeHandle, const Function& serialize, const Function& deserialize); + API_FUNCTION(NoProxy) static void AddRPC(const ScriptingTypeHandle& typeHandle, const StringAnsiView& name, const Function& execute, bool isServer, bool isClient, NetworkChannelType channel); + API_FUNCTION(NoProxy) static void CSharpEndInvokeRPC(ScriptingObject* obj, const ScriptingTypeHandle& type, const StringAnsiView& name, NetworkStream* argsStream); + static StringAnsiView GetCSharpCachedName(const StringAnsiView& name); #endif }; diff --git a/Source/Engine/Networking/NetworkRpc.h b/Source/Engine/Networking/NetworkRpc.h new file mode 100644 index 000000000..df3cec09e --- /dev/null +++ b/Source/Engine/Networking/NetworkRpc.h @@ -0,0 +1,76 @@ +// Copyright (c) 2012-2022 Wojciech Figat. All rights reserved. + +#pragma once + +#include "Engine/Core/Types/StringView.h" +#include "Engine/Core/Types/Pair.h" +#include "Engine/Core/Collections/Array.h" +#include "Engine/Core/Collections/Dictionary.h" +#include "Engine/Scripting/ScriptingType.h" + +class NetworkStream; + +// Network RPC identifier name (pair of type and function name) +typedef Pair NetworkRpcName; + +// Network RPC descriptor +struct FLAXENGINE_API NetworkRpcInfo +{ + uint8 Server : 1; + uint8 Client : 1; + uint8 Channel : 4; + void (*Execute)(ScriptingObject* obj, NetworkStream* stream, void* tag); + void (*Invoke)(ScriptingObject* obj, void** args); + void* Tag; + + /// + /// Global table for registered RPCs. Key: pair of type, RPC name. Value: RPC descriptor. + /// + static Dictionary RPCsTable; +}; + +// Gets the pointer to the RPC argument into the args buffer +template +FORCE_INLINE void NetworkRpcInitArg(Array>& args, const T& v) +{ + args.Add((void*)&v); +} + +// Gets the pointers to the RPC arguments into the args buffer +template +FORCE_INLINE void NetworkRpcInitArg(Array>& args, const T& first, Params&... params) +{ + NetworkRpcInitArg(args, first); + NetworkRpcInitArg(args, Forward(params)...); +} + +// Network RPC implementation (placed in the beginning of the method body) +#define NETWORK_RPC_IMPL(type, name, ...) \ + { \ + const NetworkRpcInfo& rpcInfo = NetworkRpcInfo::RPCsTable[NetworkRpcName(type::TypeInitializer, StringAnsiView(#name))]; \ + const NetworkManagerMode networkMode = NetworkManager::Mode; \ + if ((rpcInfo.Server && networkMode == NetworkManagerMode::Client) || (rpcInfo.Client && networkMode != NetworkManagerMode::Client)) \ + { \ + Array> args; \ + NetworkRpcInitArg(args, __VA_ARGS__); \ + rpcInfo.Invoke(this, args.Get()); \ + if (rpcInfo.Server && networkMode == NetworkManagerMode::Client) \ + return; \ + if (rpcInfo.Client && networkMode == NetworkManagerMode::Server) \ + return; \ + } \ + } + +// Network RPC override implementation (placed in the beginning of the overriden method body - after call to the base class method) +#define NETWORK_RPC_OVERRIDE_IMPL(type, name, ...) \ + { \ + const NetworkRpcInfo& rpcInfo = NetworkRpcInfo::RPCsTable[NetworkRpcName(type::TypeInitializer, StringAnsiView(#name))]; \ + const NetworkManagerMode networkMode = NetworkManager::Mode; \ + if ((rpcInfo.Server && networkMode == NetworkManagerMode::Client) || (rpcInfo.Client && networkMode != NetworkManagerMode::Client)) \ + { \ + if (rpcInfo.Server && networkMode == NetworkManagerMode::Client) \ + return; \ + if (rpcInfo.Client && networkMode == NetworkManagerMode::Server) \ + return; \ + } \ + } diff --git a/Source/Engine/Scripting/Attributes/NetworkRpcAttribute.cs b/Source/Engine/Scripting/Attributes/NetworkRpcAttribute.cs new file mode 100644 index 000000000..8543a6006 --- /dev/null +++ b/Source/Engine/Scripting/Attributes/NetworkRpcAttribute.cs @@ -0,0 +1,42 @@ +// Copyright (c) 2012-2022 Wojciech Figat. All rights reserved. + +using System; +using FlaxEngine.Networking; + +namespace FlaxEngine +{ + /// + /// Indicates that a method is Remote Procedure Call which can be invoked on client and executed on server or invoked on server and executed on clients. + /// + [AttributeUsage(AttributeTargets.Method)] + public sealed class NetworkRpcAttribute : Attribute + { + /// + /// True if RPC should be executed on server. + /// + public bool Server; + + /// + /// True if RPC should be executed on client. + /// + public bool Client; + + /// + /// Network channel using which RPC should be send. + /// + public NetworkChannelType Channel; + + /// + /// Initializes a new instance of the class. + /// + /// True if RPC should be executed on server. + /// True if RPC should be executed on client. + /// Network channel using which RPC should be send. + public NetworkRpcAttribute(bool server = false, bool client = false, NetworkChannelType channel = NetworkChannelType.ReliableOrdered) + { + Server = server; + Client = client; + Channel = channel; + } + } +} diff --git a/Source/Tools/Flax.Build/Build/Plugins/NetworkingPlugin.cs b/Source/Tools/Flax.Build/Build/Plugins/NetworkingPlugin.cs index f2b33d5a5..58dbf156e 100644 --- a/Source/Tools/Flax.Build/Build/Plugins/NetworkingPlugin.cs +++ b/Source/Tools/Flax.Build/Build/Plugins/NetworkingPlugin.cs @@ -36,8 +36,10 @@ namespace Flax.Build.Plugins public MethodDefinition Serialize; public MethodDefinition Deserialize; } - + + internal const string Network = "Network"; internal const string NetworkReplicated = "NetworkReplicated"; + internal const string NetworkRpc = "NetworkRpc"; private const string Thunk1 = "INetworkSerializable_Serialize"; private const string Thunk2 = "INetworkSerializable_Deserialize"; private static readonly Dictionary _inBuildSerializers = new Dictionary() @@ -78,34 +80,42 @@ namespace Flax.Build.Plugins private void OnParseMemberTag(ref bool valid, BindingsGenerator.TagParameter tag, MemberInfo memberInfo) { - if (tag.Tag != NetworkReplicated) - return; - - // Mark member as replicated - valid = true; - memberInfo.SetTag(NetworkReplicated, string.Empty); + if (tag.Tag == NetworkReplicated) + { + // Mark member as replicated + valid = true; + memberInfo.SetTag(NetworkReplicated, string.Empty); + } + else if (tag.Tag == NetworkRpc) + { + // Mark member as rpc + valid = true; + memberInfo.SetTag(NetworkRpc, tag.Value); + } } private void OnGenerateCppTypeInternals(Builder.BuildData buildData, ApiTypeInfo typeInfo, StringBuilder contents) { // Skip modules that don't use networking var module = BindingsGenerator.CurrentModule; - if (module.GetTag(NetworkReplicated) == null) + if (module.GetTag(Network) == null) return; - // Check if type uses automated network replication + // Check if type uses automated network replication/RPCs List fields = null; List properties = null; + List functions = null; if (typeInfo is ClassInfo classInfo) { fields = classInfo.Fields; properties = classInfo.Properties; + functions = classInfo.Functions; } else if (typeInfo is StructureInfo structInfo) { fields = structInfo.Fields; } - bool useReplication = false; + bool useReplication = false, useRpc = false; if (fields != null) { foreach (var fieldInfo in fields) @@ -128,16 +138,117 @@ namespace Flax.Build.Plugins } } } - if (!useReplication) - return; - - typeInfo.SetTag(NetworkReplicated, string.Empty); + if (functions != null) + { + foreach (var functionInfo in functions) + { + if (functionInfo.GetTag(NetworkRpc) != null) + { + useRpc = true; + break; + } + } + } + if (useReplication) + { + typeInfo.SetTag(NetworkReplicated, string.Empty); - // Generate C++ wrapper functions to serialize/deserialize type - BindingsGenerator.CppIncludeFiles.Add("Engine/Networking/NetworkReplicator.h"); - BindingsGenerator.CppIncludeFiles.Add("Engine/Networking/NetworkStream.h"); - OnGenerateCppTypeSerialize(buildData, typeInfo, contents, fields, properties, true); - OnGenerateCppTypeSerialize(buildData, typeInfo, contents, fields, properties, false); + // Generate C++ wrapper functions to serialize/deserialize type + BindingsGenerator.CppIncludeFiles.Add("Engine/Networking/NetworkReplicator.h"); + BindingsGenerator.CppIncludeFiles.Add("Engine/Networking/NetworkStream.h"); + OnGenerateCppTypeSerialize(buildData, typeInfo, contents, fields, properties, true); + OnGenerateCppTypeSerialize(buildData, typeInfo, contents, fields, properties, false); + } + if (useRpc) + { + typeInfo.SetTag(NetworkRpc, string.Empty); + + // Generate C++ wrapper functions to invoke/execute RPC + BindingsGenerator.CppIncludeFiles.Add("Engine/Networking/NetworkStream.h"); + BindingsGenerator.CppIncludeFiles.Add("Engine/Networking/NetworkReplicator.h"); + BindingsGenerator.CppIncludeFiles.Add("Engine/Networking/NetworkChannelType.h"); + BindingsGenerator.CppIncludeFiles.Add("Engine/Networking/NetworkRpc.h"); + foreach (var functionInfo in functions) + { + var tag = functionInfo.GetTag(NetworkRpc); + if (tag == null) + continue; + if (functionInfo.UniqueName != functionInfo.Name) + throw new Exception($"Invalid network RPC method {functionInfo.Name} name in type {typeInfo.Name}. Network RPC functions names have to be unique."); + bool isServer = tag.IndexOf("Server", StringComparison.OrdinalIgnoreCase) != -1; + bool isClient = tag.IndexOf("Client", StringComparison.OrdinalIgnoreCase) != -1; + if (isServer && isClient) + throw new Exception($"Network RPC {functionInfo.Name} in {typeInfo.Name} cannot be both Server and Client."); + if (!isServer && !isClient) + throw new Exception($"Network RPC {functionInfo.Name} in {typeInfo.Name} needs to have Server or Client specifier."); + var channelType = "ReliableOrdered"; + if (tag.IndexOf("UnreliableOrdered", StringComparison.OrdinalIgnoreCase) != -1) + channelType = "UnreliableOrdered"; + else if (tag.IndexOf("ReliableOrdered", StringComparison.OrdinalIgnoreCase) != -1) + channelType = "ReliableOrdered"; + else if (tag.IndexOf("Unreliable", StringComparison.OrdinalIgnoreCase) != -1) + channelType = "Unreliable"; + else if (tag.IndexOf("Reliable", StringComparison.OrdinalIgnoreCase) != -1) + channelType = "Reliable"; + + // Generated method thunk to execute RPC from network + { + contents.Append(" static void ").Append(functionInfo.Name).AppendLine("_Execute(ScriptingObject* obj, NetworkStream* stream)"); + contents.AppendLine(" {"); + string argNames = string.Empty; + for (int i = 0; i < functionInfo.Parameters.Count; i++) + { + var arg = functionInfo.Parameters[i]; + if (i != 0) + argNames += ", "; + argNames += arg.Name; + + // Deserialize arguments + contents.AppendLine($" {arg.Type.Type} {arg.Name};"); + contents.AppendLine($" stream->Read({arg.Name});"); + } + + // Call method locally + contents.AppendLine($" ASSERT(obj && obj->Is<{typeInfo.NativeName}>());"); + contents.AppendLine($" (({typeInfo.NativeName}*)obj)->{functionInfo.Name}({argNames});"); + contents.AppendLine(" }"); + } + contents.AppendLine(); + + // Generated method thunk to invoke RPC to network + { + contents.Append(" static void ").Append(functionInfo.Name).AppendLine("_Invoke(ScriptingObject* obj, void** args)"); + contents.AppendLine(" {"); + contents.AppendLine(" NetworkStream* stream = NetworkReplicator::BeginInvokeRPC();"); + for (int i = 0; i < functionInfo.Parameters.Count; i++) + { + var arg = functionInfo.Parameters[i]; + + // Serialize arguments + contents.AppendLine($" stream->Write(*({arg.Type.Type}*)args[{i}]);"); + } + + // Invoke RPC + contents.AppendLine($" NetworkReplicator::EndInvokeRPC(obj, {typeInfo.NativeName}::TypeInitializer, StringAnsiView(\"{functionInfo.Name}\", {functionInfo.Name.Length}), stream);"); + contents.AppendLine(" }"); + } + contents.AppendLine(); + + // Generated info about RPC implementation + { + contents.Append(" static NetworkRpcInfo ").Append(functionInfo.Name).AppendLine("_Info()"); + contents.AppendLine(" {"); + contents.AppendLine(" NetworkRpcInfo info;"); + contents.AppendLine($" info.Server = {(isServer ? "1" : "0")};"); + contents.AppendLine($" info.Execute = {functionInfo.Name}_Execute;"); + contents.AppendLine($" info.Invoke = {functionInfo.Name}_Invoke;"); + contents.AppendLine($" info.Channel = (uint8)NetworkChannelType::{channelType};"); + contents.AppendLine(" return info;"); + contents.AppendLine(" }"); + } + contents.AppendLine(); + } + } } private void OnGenerateCppTypeSerialize(Builder.BuildData buildData, ApiTypeInfo typeInfo, StringBuilder contents, List fields, List properties, bool serialize) @@ -265,13 +376,37 @@ namespace Flax.Build.Plugins private void OnGenerateCppTypeInitRuntime(Builder.BuildData buildData, ApiTypeInfo typeInfo, StringBuilder contents) { - if (typeInfo.GetTag(NetworkReplicated) == null) + // Skip types that don't use networking + var replicatedTag = typeInfo.GetTag(NetworkReplicated); + var rpcTag = typeInfo.GetTag(NetworkRpc); + if (replicatedTag == null && rpcTag == null) return; var typeNameNative = typeInfo.FullNameNative; var typeNameInternal = typeInfo.FullNameNativeInternal; - // Register generated serializer functions - contents.AppendLine($" NetworkReplicator::AddSerializer(ScriptingTypeHandle({typeNameNative}::TypeInitializer), {typeNameInternal}Internal::INetworkSerializable_Serialize, {typeNameInternal}Internal::INetworkSerializable_Deserialize);"); + if (replicatedTag != null) + { + // Register generated serializer functions + contents.AppendLine($" NetworkReplicator::AddSerializer(ScriptingTypeHandle({typeNameNative}::TypeInitializer), {typeNameInternal}Internal::INetworkSerializable_Serialize, {typeNameInternal}Internal::INetworkSerializable_Deserialize);"); + } + if (rpcTag != null) + { + // Register generated RPCs + List functions = null; + if (typeInfo is ClassInfo classInfo) + { + functions = classInfo.Functions; + } + if (functions != null) + { + foreach (var functionInfo in functions) + { + if (functionInfo.GetTag(NetworkRpc) == null) + continue; + contents.AppendLine($" NetworkRpcInfo::RPCsTable[NetworkRpcName({typeNameNative}::TypeInitializer, StringAnsiView(\"{functionInfo.Name}\", {functionInfo.Name.Length}))] = {functionInfo.Name}_Info();"); + } + } + } } private void OnGenerateCSharpTypeInternals(Builder.BuildData buildData, ApiTypeInfo typeInfo, StringBuilder contents, string indent) @@ -318,7 +453,7 @@ namespace Flax.Build.Plugins private void OnBuildDotNetAssembly(TaskGraph graph, Builder.BuildData buildData, NativeCpp.BuildOptions buildOptions, Task buildTask, IGrouping binaryModule) { // Skip assemblies not using netowrking - if (!binaryModule.Any(module => module.Tags.ContainsKey(NetworkReplicated))) + if (!binaryModule.Any(module => module.Tags.ContainsKey(Network))) return; // Generate netoworking code inside assembly after it's being compiled @@ -326,7 +461,7 @@ namespace Flax.Build.Plugins var task = graph.Add(); task.ProducedFiles.Add(assemblyPath); task.WorkingDirectory = buildTask.WorkingDirectory; - task.Command = () => OnPatchAssembly(buildData, buildOptions, buildTask, assemblyPath); + task.Command = () => OnPatchDotNetAssembly(buildData, buildOptions, buildTask, assemblyPath); task.CommandPath = null; task.InfoMessage = $"Generating netowrking code for {Path.GetFileName(assemblyPath)}..."; task.Cost = 50; @@ -335,7 +470,7 @@ namespace Flax.Build.Plugins task.DependentTasks.Add(buildTask); } - private void OnPatchAssembly(Builder.BuildData buildData, NativeCpp.BuildOptions buildOptions, Task buildTask, string assemblyPath) + private void OnPatchDotNetAssembly(Builder.BuildData buildData, NativeCpp.BuildOptions buildOptions, Task buildTask, string assemblyPath) { using (DefaultAssemblyResolver assemblyResolver = new DefaultAssemblyResolver()) using (AssemblyDefinition assembly = AssemblyDefinition.ReadAssembly(assemblyPath, new ReaderParameters{ ReadWrite = true, ReadSymbols = true, AssemblyResolver = assemblyResolver })) @@ -391,8 +526,6 @@ namespace Flax.Build.Plugins isNetworkReplicated = true; break; } - - if (type.IsValueType) { if (isINetworkSerializable) diff --git a/Source/Tools/Flax.Build/Utilities/MonoCecil.cs b/Source/Tools/Flax.Build/Utilities/MonoCecil.cs index 506ed3efb..6bedab2e6 100644 --- a/Source/Tools/Flax.Build/Utilities/MonoCecil.cs +++ b/Source/Tools/Flax.Build/Utilities/MonoCecil.cs @@ -41,7 +41,7 @@ namespace Flax.Build public static MethodDefinition GetMethod(this TypeDefinition type, string name, int argCount) { - var result = type.Methods.First(x => x.Name == name && x.Parameters.Count == argCount); + var result = type.Methods.FirstOrDefault(x => x.Name == name && x.Parameters.Count == argCount); if (result == null) throw new Exception($"Failed to find method '{name}' (args={argCount}) in '{type.FullName}'."); return result; @@ -49,12 +49,32 @@ namespace Flax.Build public static FieldDefinition GetField(this TypeDefinition type, string name) { - var result = type.Fields.First(x => x.Name == name); + var result = type.Fields.FirstOrDefault(x => x.Name == name); if (result == null) throw new Exception($"Failed to find field '{name}' in '{type.FullName}'."); return result; } + public static CustomAttributeNamedArgument GetField(this CustomAttribute attribute, string name) + { + foreach (var f in attribute.Fields) + { + if (f.Name == name) + return f; + } + throw new Exception($"Failed to find field '{name}' in '{attribute.AttributeType.FullName}'."); + } + + public static object GetFieldValue(this CustomAttribute attribute, string name, object defaultValue) + { + foreach (var f in attribute.Fields) + { + if (f.Name == name) + return f.Argument.Value; + } + return defaultValue; + } + public static bool IsScriptingObject(this TypeDefinition type) { if (type.FullName == "FlaxEngine.Object")