From 1b7a7dc15c74425fbadbe99c2a8e7e0281508733 Mon Sep 17 00:00:00 2001 From: Wojciech Figat Date: Wed, 16 Nov 2022 17:26:26 +0100 Subject: [PATCH] Add network RPCs to C# codegen --- .../Build/Plugins/NetworkingPlugin.cs | 377 +++++++++++++++++- 1 file changed, 372 insertions(+), 5 deletions(-) diff --git a/Source/Tools/Flax.Build/Build/Plugins/NetworkingPlugin.cs b/Source/Tools/Flax.Build/Build/Plugins/NetworkingPlugin.cs index 58dbf156e..b566271a8 100644 --- a/Source/Tools/Flax.Build/Build/Plugins/NetworkingPlugin.cs +++ b/Source/Tools/Flax.Build/Build/Plugins/NetworkingPlugin.cs @@ -36,7 +36,17 @@ namespace Flax.Build.Plugins public MethodDefinition Serialize; public MethodDefinition Deserialize; } - + + private struct MethodRPC + { + public TypeDefinition Type; + public MethodDefinition Method; + public bool IsServer; + public bool IsClient; + public int Channel; + public MethodDefinition Execute; + } + internal const string Network = "Network"; internal const string NetworkReplicated = "NetworkReplicated"; internal const string NetworkRpc = "NetworkRpc"; @@ -193,7 +203,7 @@ namespace Flax.Build.Plugins // Generated method thunk to execute RPC from network { - contents.Append(" static void ").Append(functionInfo.Name).AppendLine("_Execute(ScriptingObject* obj, NetworkStream* stream)"); + contents.Append(" static void ").Append(functionInfo.Name).AppendLine("_Execute(ScriptingObject* obj, NetworkStream* stream, void* tag)"); contents.AppendLine(" {"); string argNames = string.Empty; for (int i = 0; i < functionInfo.Parameters.Count; i++) @@ -243,6 +253,7 @@ namespace Flax.Build.Plugins contents.AppendLine($" info.Execute = {functionInfo.Name}_Execute;"); contents.AppendLine($" info.Invoke = {functionInfo.Name}_Invoke;"); contents.AppendLine($" info.Channel = (uint8)NetworkChannelType::{channelType};"); + contents.AppendLine($" info.Tag = nullptr;"); contents.AppendLine(" return info;"); contents.AppendLine(" }"); } @@ -494,6 +505,7 @@ namespace Flax.Build.Plugins bool modified = false; bool failed = false; var addSerializers = new List(); + var methodRPCs = new List(); foreach (TypeDefinition type in module.Types) { if (type.IsInterface || type.IsEnum) @@ -501,6 +513,21 @@ namespace Flax.Build.Plugins var isNative = type.HasAttribute("FlaxEngine.UnmanagedAttribute"); if (isNative) continue; + + // Generate RPCs + var methods = type.Methods; + var methodsCount = methods.Count; // methods list can be modified during RPCs generation + for (int i = 0; i < methodsCount; i++) + { + MethodDefinition method = methods[i]; + var attribute = method.CustomAttributes.FirstOrDefault(x => x.AttributeType.FullName == "FlaxEngine.NetworkRpcAttribute"); + if (attribute != null) + { + GenerateDotNetRPCBody(type, method, attribute, ref failed, networkStreamType, methodRPCs); + } + } + + // Generate serializers if (type.HasMethod(Thunk1) || type.HasMethod(Thunk2)) continue; var isINetworkSerializable = type.HasInterface("FlaxEngine.Networking.INetworkSerializable"); @@ -560,7 +587,7 @@ namespace Flax.Build.Plugins return; // Generate serializers initializer (invoked on module load) - if (addSerializers.Count != 0) + if (addSerializers.Count != 0 || methodRPCs.Count != 0) { // Create class var name = "Initializer"; @@ -590,6 +617,10 @@ namespace Flax.Build.Plugins module.ImportReference(addSerializer); var serializeFuncType = addSerializer.Parameters[1].ParameterType; var serializeFuncCtor = serializeFuncType.Resolve().GetMethod(".ctor"); + var addRPC = networkReplicatorType.Resolve().GetMethod("AddRPC", 6); + module.ImportReference(addRPC); + var executeRPCFuncType = addRPC.Parameters[2].ParameterType; + var executeRPCFuncCtor = executeRPCFuncType.Resolve().GetMethod(".ctor"); foreach (var e in addSerializers) { // NetworkReplicator.AddSerializer(typeof(), .INetworkSerializable_SerializeNative, .INetworkSerializable_DeserializeNative); @@ -603,6 +634,20 @@ namespace Flax.Build.Plugins il.Emit(OpCodes.Newobj, module.ImportReference(serializeFuncCtor)); il.Emit(OpCodes.Call, module.ImportReference(addSerializer)); } + foreach (var e in methodRPCs) + { + // NetworkReplicator.AddRPC(typeof(), "", _Execute, , , ); + il.Emit(OpCodes.Ldtoken, e.Type); + il.Emit(OpCodes.Call, module.ImportReference(getTypeFromHandle)); + il.Emit(OpCodes.Ldstr, e.Method.Name); + il.Emit(OpCodes.Ldnull); + il.Emit(OpCodes.Ldftn, e.Execute); + il.Emit(OpCodes.Newobj, module.ImportReference(executeRPCFuncCtor)); + il.Emit(OpCodes.Ldc_I4, e.IsServer ? 1 : 0); + il.Emit(OpCodes.Ldc_I4, e.IsClient ? 1 : 0); + il.Emit(OpCodes.Ldc_I4, e.Channel); + il.Emit(OpCodes.Call, module.ImportReference(addRPC)); + } il.Emit(OpCodes.Nop); il.Emit(OpCodes.Ret); c.Methods.Add(m); @@ -774,8 +819,8 @@ namespace Flax.Build.Plugins var getID = scriptingObjectType.Resolve().GetMethod("get_ID"); il.Emit(OpCodes.Call, module.ImportReference(getID)); il.Append(jmp2); - var m = networkStreamType.GetMethod("WriteGuid"); - il.Emit(OpCodes.Callvirt, module.ImportReference(m)); + var writeGuid = networkStreamType.GetMethod("WriteGuid"); + il.Emit(OpCodes.Callvirt, module.ImportReference(writeGuid)); } else { @@ -824,6 +869,7 @@ namespace Flax.Build.Plugins else if (valueType.IsValueType) { // Invoke structure generated serializer + // TODO: check if this type has generated serialization code il.Emit(OpCodes.Ldarg_0); il.Emit(OpCodes.Ldflda, field); il.Emit(OpCodes.Ldarg_1); @@ -970,5 +1016,326 @@ namespace Flax.Build.Plugins failed = true; } } + + private static void GenerateDotNetRPCSerializerType(TypeDefinition type, bool serialize, ref bool failed, int localIndex, TypeReference valueType, ILProcessor il, TypeDefinition networkStreamType, int streamLocalIndex, Instruction ilStart) + { + ModuleDefinition module = type.Module; + TypeDefinition valueTypeDef = valueType.Resolve(); + if (_inBuildSerializers.TryGetValue(valueType.FullName, out var serializer)) + { + // Call NetworkStream method to write/read data + if (serialize) + { + il.InsertBefore(ilStart, il.Create(OpCodes.Ldloc, streamLocalIndex)); + il.InsertBefore(ilStart, il.Create(OpCodes.Ldarg, localIndex)); + il.InsertBefore(ilStart, il.Create(OpCodes.Callvirt, module.ImportReference(networkStreamType.GetMethod(serializer.WriteMethod)))); + } + else + { + il.Emit(OpCodes.Ldloc_1); + il.Emit(OpCodes.Callvirt, module.ImportReference(networkStreamType.GetMethod(serializer.ReadMethod))); + il.Emit(OpCodes.Stloc, localIndex); + } + } + else if (valueType.IsScriptingObject()) + { + // Replicate ScriptingObject as Guid ID + module.GetType("System.Guid", out var guidType); + module.GetType("FlaxEngine.Object", out var scriptingObjectType); + if (serialize) + { + il.InsertBefore(ilStart, il.Create(OpCodes.Nop)); + il.InsertBefore(ilStart, il.Create(OpCodes.Ldloc, streamLocalIndex)); + il.InsertBefore(ilStart, il.Create(OpCodes.Ldarg, localIndex)); + Instruction jmp1 = il.Create(OpCodes.Nop); + il.InsertBefore(ilStart, il.Create(OpCodes.Brtrue_S, jmp1)); + var guidEmpty = guidType.Resolve().GetField("Empty"); + il.InsertBefore(ilStart, il.Create(OpCodes.Ldsfld, module.ImportReference(guidEmpty))); + Instruction jmp2 = il.Create(OpCodes.Nop); + il.InsertBefore(ilStart, il.Create(OpCodes.Br_S, jmp2)); + il.InsertBefore(ilStart, jmp1); + il.InsertBefore(ilStart, il.Create(OpCodes.Ldarg, localIndex)); + var getID = scriptingObjectType.Resolve().GetMethod("get_ID"); + il.InsertBefore(ilStart, il.Create(OpCodes.Call, module.ImportReference(getID))); + il.InsertBefore(ilStart, jmp2); + var writeGuid = networkStreamType.GetMethod("WriteGuid"); + il.InsertBefore(ilStart, il.Create(OpCodes.Callvirt, module.ImportReference(writeGuid))); + } + else + { + var m = networkStreamType.GetMethod("ReadGuid"); + module.GetType("System.Type", out var typeType); + il.Emit(OpCodes.Ldloc_1); + il.Emit(OpCodes.Callvirt, module.ImportReference(m)); + var varStart = il.Body.Variables.Count; + var reference = module.ImportReference(guidType); + reference.IsValueType = true; // Fix locals init to have valuetype for Guid instead of class + il.Body.Variables.Add(new VariableDefinition(reference)); + il.Body.InitLocals = true; + il.Emit(OpCodes.Stloc_S, (byte)varStart); + il.Emit(OpCodes.Ldloca_S, (byte)varStart); + il.Emit(OpCodes.Ldtoken, valueType); + var getTypeFromHandle = typeType.Resolve().GetMethod("GetTypeFromHandle"); + il.Emit(OpCodes.Call, module.ImportReference(getTypeFromHandle)); + var tryFind = scriptingObjectType.Resolve().GetMethod("TryFind", 2); + il.Emit(OpCodes.Call, module.ImportReference(tryFind)); + il.Emit(OpCodes.Castclass, valueType); + il.Emit(OpCodes.Stloc, localIndex); + } + } + else if (valueTypeDef.IsEnum) + { + // Replicate enum as bits + // TODO: use smaller uint depending on enum values range + if (serialize) + { + il.InsertBefore(ilStart, il.Create(OpCodes.Ldloc, streamLocalIndex)); + il.InsertBefore(ilStart, il.Create(OpCodes.Ldarg, localIndex)); + il.InsertBefore(ilStart, il.Create(OpCodes.Callvirt, module.ImportReference(networkStreamType.GetMethod("WriteUInt32")))); + } + else + { + il.Emit(OpCodes.Ldloc_1); + var m = networkStreamType.GetMethod("ReadUInt32"); + il.Emit(OpCodes.Callvirt, module.ImportReference(m)); + il.Emit(OpCodes.Stloc, localIndex); + } + } + else if (valueType.IsValueType) + { + // Invoke structure generated serializer + // TODO: check if this type has generated serialization code + if (serialize) + { + il.InsertBefore(ilStart, il.Create(OpCodes.Nop)); + il.InsertBefore(ilStart, il.Create(OpCodes.Ldarga, localIndex)); + il.InsertBefore(ilStart, il.Create(OpCodes.Ldloc, streamLocalIndex)); + il.InsertBefore(ilStart, il.Create(OpCodes.Call, module.ImportReference(valueTypeDef.GetMethod(Thunk1)))); + } + else + { + il.Emit(OpCodes.Ldloca_S, (byte)localIndex); + il.Emit(OpCodes.Initobj, valueType); + il.Emit(OpCodes.Ldloca_S, (byte)localIndex); + il.Emit(OpCodes.Ldloc_1); + il.Emit(OpCodes.Call, module.ImportReference(valueTypeDef.GetMethod(Thunk2))); + } + } + else + { + // Unknown type + Log.Error($"Not supported type '{valueType.FullName}' for RPC parameter in {type.FullName}."); + failed = true; + } + } + + private static void GenerateDotNetRPCBody(TypeDefinition type, MethodDefinition method, CustomAttribute attribute, ref bool failed, TypeReference networkStreamType, List methodRPCs) + { + // Validate RPC usage + if (method.IsAbstract) + { + Log.Error($"Not supported abstract RPC method '{method.FullName}'."); + failed = true; + return; + } + if (method.IsVirtual) + { + Log.Error($"Not supported virtual RPC method '{method.FullName}'."); + failed = true; + return; + } + ModuleDefinition module = type.Module; + var voidType = module.TypeSystem.Void; + if (method.ReturnType != voidType) + { + Log.Error($"Not supported non-void RPC method '{method.FullName}'."); + failed = true; + return; + } + if (method.IsStatic) + { + Log.Error($"Not supported static RPC method '{method.FullName}'."); + failed = true; + return; + } + var methodRPC = new MethodRPC(); + methodRPC.Type = type; + methodRPC.Method = method; + methodRPC.IsServer = (bool)attribute.GetFieldValue("Server", false); + methodRPC.IsClient = (bool)attribute.GetFieldValue("Client", false); + if (methodRPC.IsServer && methodRPC.IsClient) + { + Log.Error($"Network RPC {method.Name} in {type.FullName} cannot be both Server and Client."); + failed = true; + return; + } + if (!methodRPC.IsServer && !methodRPC.IsClient) + { + Log.Error($"Network RPC {method.Name} in {type.FullName} needs to have Server or Client specifier."); + failed = true; + return; + } + methodRPC.Channel = (int)attribute.GetFieldValue("Channel", 4); // int as NetworkChannelType (default is ReliableOrdered=4) + module.GetType("System.IntPtr", out var intPtrType); + module.GetType("FlaxEngine.Object", out var scriptingObjectType); + var fromUnmanagedPtr = scriptingObjectType.Resolve().GetMethod("FromUnmanagedPtr"); + TypeReference networkStream = module.ImportReference(networkStreamType); + + // Generate static method to execute RPC locally + { + var m = new MethodDefinition(method.Name + "_Execute", MethodAttributes.Static | MethodAttributes.Private | MethodAttributes.HideBySig, voidType); + m.Parameters.Add(new ParameterDefinition("instancePtr", ParameterAttributes.None, intPtrType)); + m.Parameters.Add(new ParameterDefinition("streamPtr", ParameterAttributes.None, module.ImportReference(intPtrType))); + ILProcessor il = m.Body.GetILProcessor(); + il.Emit(OpCodes.Nop); + il.Body.InitLocals = true; + + // instance = ()FlaxEngine.Object.FromUnmanagedPtr(instancePtr) + il.Body.Variables.Add(new VariableDefinition(type)); + il.Emit(OpCodes.Ldarg_0); + il.Emit(OpCodes.Call, module.ImportReference(fromUnmanagedPtr)); + il.Emit(OpCodes.Castclass, type); + il.Emit(OpCodes.Stloc_0); + + // NetworkStream stream = (NetworkStream)FlaxEngine.Object.FromUnmanagedPtr(streamPtr) + il.Body.Variables.Add(new VariableDefinition(networkStream)); + il.Emit(OpCodes.Ldarg_1); + il.Emit(OpCodes.Call, module.ImportReference(fromUnmanagedPtr)); + il.Emit(OpCodes.Castclass, networkStream); + il.Emit(OpCodes.Stloc_1); + + // Add locals for each RPC parameter + var argsStart = il.Body.Variables.Count; + for (int i = 0; i < method.Parameters.Count; i++) + { + var parameter = method.Parameters[i]; + if (parameter.IsOut) + { + Log.Error($"Network RPC {method.Name} in {type.FullName} parameter {parameter.Name} cannot be 'out'."); + failed = true; + return; + } + var parameterType = parameter.ParameterType; + il.Body.Variables.Add(new VariableDefinition(parameterType)); + } + + // Deserialize parameters from the stream + for (int i = 0; i < method.Parameters.Count; i++) + { + var parameter = method.Parameters[i]; + var parameterType = parameter.ParameterType; + GenerateDotNetRPCSerializerType(type, false, ref failed, argsStart + i, parameterType, il, networkStream.Resolve(), 1, null); + } + + // Call RPC method body + il.Emit(OpCodes.Ldloc_0); + for (int i = 0; i < method.Parameters.Count; i++) + { + il.Emit(OpCodes.Ldloc, argsStart + i); + } + il.Emit(OpCodes.Callvirt, method); + + il.Emit(OpCodes.Nop); + il.Emit(OpCodes.Ret); + type.Methods.Add(m); + methodRPC.Execute = m; + } + + // Inject custom code before RPC method body to invoke it + { + ILProcessor il = method.Body.GetILProcessor(); + Instruction ilStart = il.Body.Instructions[0]; + module.GetType("System.Boolean", out var boolType); + module.GetType("FlaxEngine.Networking.NetworkManagerMode", out var networkManagerModeType); + module.GetType("FlaxEngine.Networking.NetworkManager", out var networkManagerType); + var networkManagerGetMode = networkManagerType.Resolve().GetMethod("get_Mode", 0); + il.Body.InitLocals = true; + var varsStart = il.Body.Variables.Count; + + // Is Server/Is Client boolean constants + il.Body.Variables.Add(new VariableDefinition(module.ImportReference(boolType))); // [0] + il.Body.Variables.Add(new VariableDefinition(module.ImportReference(boolType))); // [1] + il.InsertBefore(ilStart, il.Create(OpCodes.Ldc_I4, methodRPC.IsServer ? 1 : 0)); + il.InsertBefore(ilStart, il.Create(OpCodes.Stloc, varsStart + 0)); // isServer loc=0 + il.InsertBefore(ilStart, il.Create(OpCodes.Ldc_I4, methodRPC.IsClient ? 1 : 0)); + il.InsertBefore(ilStart, il.Create(OpCodes.Stloc, varsStart + 1)); // isClient loc=1 + + // NetworkManagerMode mode = NetworkManager.Mode; + il.Body.Variables.Add(new VariableDefinition(module.ImportReference(networkManagerModeType))); // [2] + il.InsertBefore(ilStart, il.Create(OpCodes.Call, module.ImportReference(networkManagerGetMode))); + il.InsertBefore(ilStart, il.Create(OpCodes.Stloc, varsStart + 2)); // mode loc=2 + + // if ((server && networkMode == NetworkManagerMode.Client) || (client && networkMode != NetworkManagerMode.Client)) + var jumpIfBodyStart = il.Create(OpCodes.Nop); // if block body + var jumpIf2Start = il.Create(OpCodes.Nop); // 2nd part of the if + var jumpBodyStart = il.Create(OpCodes.Nop); // original method body start + var jumpBodyEnd = il.Body.Instructions.First(x => x.OpCode == OpCodes.Ret); + il.InsertBefore(ilStart, il.Create(OpCodes.Ldloc, varsStart + 0)); + il.InsertBefore(ilStart, il.Create(OpCodes.Brfalse_S, jumpIf2Start)); + il.InsertBefore(ilStart, il.Create(OpCodes.Ldloc, varsStart + 2)); + il.InsertBefore(ilStart, il.Create(OpCodes.Ldc_I4_2)); + il.InsertBefore(ilStart, il.Create(OpCodes.Beq_S, jumpIfBodyStart)); + // || + il.InsertBefore(ilStart, jumpIf2Start); + il.InsertBefore(ilStart, il.Create(OpCodes.Ldloc, varsStart + 1)); + il.InsertBefore(ilStart, il.Create(OpCodes.Brfalse_S, jumpBodyStart)); + il.InsertBefore(ilStart, il.Create(OpCodes.Ldloc, varsStart + 2)); + il.InsertBefore(ilStart, il.Create(OpCodes.Ldc_I4_2)); + il.InsertBefore(ilStart, il.Create(OpCodes.Beq_S, jumpBodyStart)); + // { + il.InsertBefore(ilStart, jumpIfBodyStart); + + // NetworkStream stream = NetworkReplicator.BeginInvokeRPC(); + il.Body.Variables.Add(new VariableDefinition(module.ImportReference(networkStream))); // [3] + var streamLocalIndex = varsStart + 3; + module.GetType("FlaxEngine.Networking.NetworkReplicator", out var networkReplicatorType); + var beginInvokeRPC = networkReplicatorType.Resolve().GetMethod("BeginInvokeRPC", 0); + il.InsertBefore(ilStart, il.Create(OpCodes.Call, module.ImportReference(beginInvokeRPC))); + il.InsertBefore(ilStart, il.Create(OpCodes.Stloc, streamLocalIndex)); // stream loc=3 + + // Serialize all RPC parameters + for (int i = 0; i < method.Parameters.Count; i++) + { + var parameter = method.Parameters[i]; + var parameterType = parameter.ParameterType; + GenerateDotNetRPCSerializerType(type, true, ref failed, i + 1, parameterType, il, networkStream.Resolve(), streamLocalIndex, ilStart); + } + + // NetworkReplicator.EndInvokeRPC(this, typeof(), "", stream); + il.InsertBefore(ilStart, il.Create(OpCodes.Nop)); + il.InsertBefore(ilStart, il.Create(OpCodes.Ldarg_0)); + il.InsertBefore(ilStart, il.Create(OpCodes.Ldtoken, type)); + module.GetType("System.Type", out var typeType); + var getTypeFromHandle = typeType.Resolve().GetMethod("GetTypeFromHandle"); + il.InsertBefore(ilStart, il.Create(OpCodes.Call, module.ImportReference(getTypeFromHandle))); + il.InsertBefore(ilStart, il.Create(OpCodes.Ldstr, method.Name)); + il.InsertBefore(ilStart, il.Create(OpCodes.Ldloc, streamLocalIndex)); + var endInvokeRPC = networkReplicatorType.Resolve().GetMethod("EndInvokeRPC", 4); + il.InsertBefore(ilStart, il.Create(OpCodes.Call, module.ImportReference(endInvokeRPC))); + + // if (server && networkMode == NetworkManagerMode.Client) return; + if (methodRPC.IsServer) + { + il.InsertBefore(ilStart, il.Create(OpCodes.Nop)); + il.InsertBefore(ilStart, il.Create(OpCodes.Ldloc, varsStart + 2)); + il.InsertBefore(ilStart, il.Create(OpCodes.Ldc_I4_2)); + il.InsertBefore(ilStart, il.Create(OpCodes.Beq_S, jumpBodyEnd)); + } + + // if (client && networkMode == NetworkManagerMode.Server) return; + if (methodRPC.IsClient) + { + il.InsertBefore(ilStart, il.Create(OpCodes.Nop)); + il.InsertBefore(ilStart, il.Create(OpCodes.Ldloc, varsStart + 2)); + il.InsertBefore(ilStart, il.Create(OpCodes.Ldc_I4_1)); + il.InsertBefore(ilStart, il.Create(OpCodes.Beq_S, jumpBodyEnd)); + } + + // Continue to original method body + il.InsertBefore(ilStart, jumpBodyStart); + } + + methodRPCs.Add(methodRPC); + } } }