Add network RPCs to C# codegen

This commit is contained in:
Wojciech Figat
2022-11-16 17:26:26 +01:00
parent efb48697fa
commit 1b7a7dc15c

View File

@@ -36,7 +36,17 @@ namespace Flax.Build.Plugins
public MethodDefinition Serialize;
public MethodDefinition Deserialize;
}
private struct MethodRPC
{
public TypeDefinition Type;
public MethodDefinition Method;
public bool IsServer;
public bool IsClient;
public int Channel;
public MethodDefinition Execute;
}
internal const string Network = "Network";
internal const string NetworkReplicated = "NetworkReplicated";
internal const string NetworkRpc = "NetworkRpc";
@@ -193,7 +203,7 @@ namespace Flax.Build.Plugins
// Generated method thunk to execute RPC from network
{
contents.Append(" static void ").Append(functionInfo.Name).AppendLine("_Execute(ScriptingObject* obj, NetworkStream* stream)");
contents.Append(" static void ").Append(functionInfo.Name).AppendLine("_Execute(ScriptingObject* obj, NetworkStream* stream, void* tag)");
contents.AppendLine(" {");
string argNames = string.Empty;
for (int i = 0; i < functionInfo.Parameters.Count; i++)
@@ -243,6 +253,7 @@ namespace Flax.Build.Plugins
contents.AppendLine($" info.Execute = {functionInfo.Name}_Execute;");
contents.AppendLine($" info.Invoke = {functionInfo.Name}_Invoke;");
contents.AppendLine($" info.Channel = (uint8)NetworkChannelType::{channelType};");
contents.AppendLine($" info.Tag = nullptr;");
contents.AppendLine(" return info;");
contents.AppendLine(" }");
}
@@ -494,6 +505,7 @@ namespace Flax.Build.Plugins
bool modified = false;
bool failed = false;
var addSerializers = new List<TypeSerializer>();
var methodRPCs = new List<MethodRPC>();
foreach (TypeDefinition type in module.Types)
{
if (type.IsInterface || type.IsEnum)
@@ -501,6 +513,21 @@ namespace Flax.Build.Plugins
var isNative = type.HasAttribute("FlaxEngine.UnmanagedAttribute");
if (isNative)
continue;
// Generate RPCs
var methods = type.Methods;
var methodsCount = methods.Count; // methods list can be modified during RPCs generation
for (int i = 0; i < methodsCount; i++)
{
MethodDefinition method = methods[i];
var attribute = method.CustomAttributes.FirstOrDefault(x => x.AttributeType.FullName == "FlaxEngine.NetworkRpcAttribute");
if (attribute != null)
{
GenerateDotNetRPCBody(type, method, attribute, ref failed, networkStreamType, methodRPCs);
}
}
// Generate serializers
if (type.HasMethod(Thunk1) || type.HasMethod(Thunk2))
continue;
var isINetworkSerializable = type.HasInterface("FlaxEngine.Networking.INetworkSerializable");
@@ -560,7 +587,7 @@ namespace Flax.Build.Plugins
return;
// Generate serializers initializer (invoked on module load)
if (addSerializers.Count != 0)
if (addSerializers.Count != 0 || methodRPCs.Count != 0)
{
// Create class
var name = "Initializer";
@@ -590,6 +617,10 @@ namespace Flax.Build.Plugins
module.ImportReference(addSerializer);
var serializeFuncType = addSerializer.Parameters[1].ParameterType;
var serializeFuncCtor = serializeFuncType.Resolve().GetMethod(".ctor");
var addRPC = networkReplicatorType.Resolve().GetMethod("AddRPC", 6);
module.ImportReference(addRPC);
var executeRPCFuncType = addRPC.Parameters[2].ParameterType;
var executeRPCFuncCtor = executeRPCFuncType.Resolve().GetMethod(".ctor");
foreach (var e in addSerializers)
{
// NetworkReplicator.AddSerializer(typeof(<type>), <type>.INetworkSerializable_SerializeNative, <type>.INetworkSerializable_DeserializeNative);
@@ -603,6 +634,20 @@ namespace Flax.Build.Plugins
il.Emit(OpCodes.Newobj, module.ImportReference(serializeFuncCtor));
il.Emit(OpCodes.Call, module.ImportReference(addSerializer));
}
foreach (var e in methodRPCs)
{
// NetworkReplicator.AddRPC(typeof(<type>), "<name>", <name>_Execute, <isServer>, <isClient>, <channel>);
il.Emit(OpCodes.Ldtoken, e.Type);
il.Emit(OpCodes.Call, module.ImportReference(getTypeFromHandle));
il.Emit(OpCodes.Ldstr, e.Method.Name);
il.Emit(OpCodes.Ldnull);
il.Emit(OpCodes.Ldftn, e.Execute);
il.Emit(OpCodes.Newobj, module.ImportReference(executeRPCFuncCtor));
il.Emit(OpCodes.Ldc_I4, e.IsServer ? 1 : 0);
il.Emit(OpCodes.Ldc_I4, e.IsClient ? 1 : 0);
il.Emit(OpCodes.Ldc_I4, e.Channel);
il.Emit(OpCodes.Call, module.ImportReference(addRPC));
}
il.Emit(OpCodes.Nop);
il.Emit(OpCodes.Ret);
c.Methods.Add(m);
@@ -774,8 +819,8 @@ namespace Flax.Build.Plugins
var getID = scriptingObjectType.Resolve().GetMethod("get_ID");
il.Emit(OpCodes.Call, module.ImportReference(getID));
il.Append(jmp2);
var m = networkStreamType.GetMethod("WriteGuid");
il.Emit(OpCodes.Callvirt, module.ImportReference(m));
var writeGuid = networkStreamType.GetMethod("WriteGuid");
il.Emit(OpCodes.Callvirt, module.ImportReference(writeGuid));
}
else
{
@@ -824,6 +869,7 @@ namespace Flax.Build.Plugins
else if (valueType.IsValueType)
{
// Invoke structure generated serializer
// TODO: check if this type has generated serialization code
il.Emit(OpCodes.Ldarg_0);
il.Emit(OpCodes.Ldflda, field);
il.Emit(OpCodes.Ldarg_1);
@@ -970,5 +1016,326 @@ namespace Flax.Build.Plugins
failed = true;
}
}
private static void GenerateDotNetRPCSerializerType(TypeDefinition type, bool serialize, ref bool failed, int localIndex, TypeReference valueType, ILProcessor il, TypeDefinition networkStreamType, int streamLocalIndex, Instruction ilStart)
{
ModuleDefinition module = type.Module;
TypeDefinition valueTypeDef = valueType.Resolve();
if (_inBuildSerializers.TryGetValue(valueType.FullName, out var serializer))
{
// Call NetworkStream method to write/read data
if (serialize)
{
il.InsertBefore(ilStart, il.Create(OpCodes.Ldloc, streamLocalIndex));
il.InsertBefore(ilStart, il.Create(OpCodes.Ldarg, localIndex));
il.InsertBefore(ilStart, il.Create(OpCodes.Callvirt, module.ImportReference(networkStreamType.GetMethod(serializer.WriteMethod))));
}
else
{
il.Emit(OpCodes.Ldloc_1);
il.Emit(OpCodes.Callvirt, module.ImportReference(networkStreamType.GetMethod(serializer.ReadMethod)));
il.Emit(OpCodes.Stloc, localIndex);
}
}
else if (valueType.IsScriptingObject())
{
// Replicate ScriptingObject as Guid ID
module.GetType("System.Guid", out var guidType);
module.GetType("FlaxEngine.Object", out var scriptingObjectType);
if (serialize)
{
il.InsertBefore(ilStart, il.Create(OpCodes.Nop));
il.InsertBefore(ilStart, il.Create(OpCodes.Ldloc, streamLocalIndex));
il.InsertBefore(ilStart, il.Create(OpCodes.Ldarg, localIndex));
Instruction jmp1 = il.Create(OpCodes.Nop);
il.InsertBefore(ilStart, il.Create(OpCodes.Brtrue_S, jmp1));
var guidEmpty = guidType.Resolve().GetField("Empty");
il.InsertBefore(ilStart, il.Create(OpCodes.Ldsfld, module.ImportReference(guidEmpty)));
Instruction jmp2 = il.Create(OpCodes.Nop);
il.InsertBefore(ilStart, il.Create(OpCodes.Br_S, jmp2));
il.InsertBefore(ilStart, jmp1);
il.InsertBefore(ilStart, il.Create(OpCodes.Ldarg, localIndex));
var getID = scriptingObjectType.Resolve().GetMethod("get_ID");
il.InsertBefore(ilStart, il.Create(OpCodes.Call, module.ImportReference(getID)));
il.InsertBefore(ilStart, jmp2);
var writeGuid = networkStreamType.GetMethod("WriteGuid");
il.InsertBefore(ilStart, il.Create(OpCodes.Callvirt, module.ImportReference(writeGuid)));
}
else
{
var m = networkStreamType.GetMethod("ReadGuid");
module.GetType("System.Type", out var typeType);
il.Emit(OpCodes.Ldloc_1);
il.Emit(OpCodes.Callvirt, module.ImportReference(m));
var varStart = il.Body.Variables.Count;
var reference = module.ImportReference(guidType);
reference.IsValueType = true; // Fix locals init to have valuetype for Guid instead of class
il.Body.Variables.Add(new VariableDefinition(reference));
il.Body.InitLocals = true;
il.Emit(OpCodes.Stloc_S, (byte)varStart);
il.Emit(OpCodes.Ldloca_S, (byte)varStart);
il.Emit(OpCodes.Ldtoken, valueType);
var getTypeFromHandle = typeType.Resolve().GetMethod("GetTypeFromHandle");
il.Emit(OpCodes.Call, module.ImportReference(getTypeFromHandle));
var tryFind = scriptingObjectType.Resolve().GetMethod("TryFind", 2);
il.Emit(OpCodes.Call, module.ImportReference(tryFind));
il.Emit(OpCodes.Castclass, valueType);
il.Emit(OpCodes.Stloc, localIndex);
}
}
else if (valueTypeDef.IsEnum)
{
// Replicate enum as bits
// TODO: use smaller uint depending on enum values range
if (serialize)
{
il.InsertBefore(ilStart, il.Create(OpCodes.Ldloc, streamLocalIndex));
il.InsertBefore(ilStart, il.Create(OpCodes.Ldarg, localIndex));
il.InsertBefore(ilStart, il.Create(OpCodes.Callvirt, module.ImportReference(networkStreamType.GetMethod("WriteUInt32"))));
}
else
{
il.Emit(OpCodes.Ldloc_1);
var m = networkStreamType.GetMethod("ReadUInt32");
il.Emit(OpCodes.Callvirt, module.ImportReference(m));
il.Emit(OpCodes.Stloc, localIndex);
}
}
else if (valueType.IsValueType)
{
// Invoke structure generated serializer
// TODO: check if this type has generated serialization code
if (serialize)
{
il.InsertBefore(ilStart, il.Create(OpCodes.Nop));
il.InsertBefore(ilStart, il.Create(OpCodes.Ldarga, localIndex));
il.InsertBefore(ilStart, il.Create(OpCodes.Ldloc, streamLocalIndex));
il.InsertBefore(ilStart, il.Create(OpCodes.Call, module.ImportReference(valueTypeDef.GetMethod(Thunk1))));
}
else
{
il.Emit(OpCodes.Ldloca_S, (byte)localIndex);
il.Emit(OpCodes.Initobj, valueType);
il.Emit(OpCodes.Ldloca_S, (byte)localIndex);
il.Emit(OpCodes.Ldloc_1);
il.Emit(OpCodes.Call, module.ImportReference(valueTypeDef.GetMethod(Thunk2)));
}
}
else
{
// Unknown type
Log.Error($"Not supported type '{valueType.FullName}' for RPC parameter in {type.FullName}.");
failed = true;
}
}
private static void GenerateDotNetRPCBody(TypeDefinition type, MethodDefinition method, CustomAttribute attribute, ref bool failed, TypeReference networkStreamType, List<MethodRPC> methodRPCs)
{
// Validate RPC usage
if (method.IsAbstract)
{
Log.Error($"Not supported abstract RPC method '{method.FullName}'.");
failed = true;
return;
}
if (method.IsVirtual)
{
Log.Error($"Not supported virtual RPC method '{method.FullName}'.");
failed = true;
return;
}
ModuleDefinition module = type.Module;
var voidType = module.TypeSystem.Void;
if (method.ReturnType != voidType)
{
Log.Error($"Not supported non-void RPC method '{method.FullName}'.");
failed = true;
return;
}
if (method.IsStatic)
{
Log.Error($"Not supported static RPC method '{method.FullName}'.");
failed = true;
return;
}
var methodRPC = new MethodRPC();
methodRPC.Type = type;
methodRPC.Method = method;
methodRPC.IsServer = (bool)attribute.GetFieldValue("Server", false);
methodRPC.IsClient = (bool)attribute.GetFieldValue("Client", false);
if (methodRPC.IsServer && methodRPC.IsClient)
{
Log.Error($"Network RPC {method.Name} in {type.FullName} cannot be both Server and Client.");
failed = true;
return;
}
if (!methodRPC.IsServer && !methodRPC.IsClient)
{
Log.Error($"Network RPC {method.Name} in {type.FullName} needs to have Server or Client specifier.");
failed = true;
return;
}
methodRPC.Channel = (int)attribute.GetFieldValue("Channel", 4); // int as NetworkChannelType (default is ReliableOrdered=4)
module.GetType("System.IntPtr", out var intPtrType);
module.GetType("FlaxEngine.Object", out var scriptingObjectType);
var fromUnmanagedPtr = scriptingObjectType.Resolve().GetMethod("FromUnmanagedPtr");
TypeReference networkStream = module.ImportReference(networkStreamType);
// Generate static method to execute RPC locally
{
var m = new MethodDefinition(method.Name + "_Execute", MethodAttributes.Static | MethodAttributes.Private | MethodAttributes.HideBySig, voidType);
m.Parameters.Add(new ParameterDefinition("instancePtr", ParameterAttributes.None, intPtrType));
m.Parameters.Add(new ParameterDefinition("streamPtr", ParameterAttributes.None, module.ImportReference(intPtrType)));
ILProcessor il = m.Body.GetILProcessor();
il.Emit(OpCodes.Nop);
il.Body.InitLocals = true;
// <type> instance = (<type>)FlaxEngine.Object.FromUnmanagedPtr(instancePtr)
il.Body.Variables.Add(new VariableDefinition(type));
il.Emit(OpCodes.Ldarg_0);
il.Emit(OpCodes.Call, module.ImportReference(fromUnmanagedPtr));
il.Emit(OpCodes.Castclass, type);
il.Emit(OpCodes.Stloc_0);
// NetworkStream stream = (NetworkStream)FlaxEngine.Object.FromUnmanagedPtr(streamPtr)
il.Body.Variables.Add(new VariableDefinition(networkStream));
il.Emit(OpCodes.Ldarg_1);
il.Emit(OpCodes.Call, module.ImportReference(fromUnmanagedPtr));
il.Emit(OpCodes.Castclass, networkStream);
il.Emit(OpCodes.Stloc_1);
// Add locals for each RPC parameter
var argsStart = il.Body.Variables.Count;
for (int i = 0; i < method.Parameters.Count; i++)
{
var parameter = method.Parameters[i];
if (parameter.IsOut)
{
Log.Error($"Network RPC {method.Name} in {type.FullName} parameter {parameter.Name} cannot be 'out'.");
failed = true;
return;
}
var parameterType = parameter.ParameterType;
il.Body.Variables.Add(new VariableDefinition(parameterType));
}
// Deserialize parameters from the stream
for (int i = 0; i < method.Parameters.Count; i++)
{
var parameter = method.Parameters[i];
var parameterType = parameter.ParameterType;
GenerateDotNetRPCSerializerType(type, false, ref failed, argsStart + i, parameterType, il, networkStream.Resolve(), 1, null);
}
// Call RPC method body
il.Emit(OpCodes.Ldloc_0);
for (int i = 0; i < method.Parameters.Count; i++)
{
il.Emit(OpCodes.Ldloc, argsStart + i);
}
il.Emit(OpCodes.Callvirt, method);
il.Emit(OpCodes.Nop);
il.Emit(OpCodes.Ret);
type.Methods.Add(m);
methodRPC.Execute = m;
}
// Inject custom code before RPC method body to invoke it
{
ILProcessor il = method.Body.GetILProcessor();
Instruction ilStart = il.Body.Instructions[0];
module.GetType("System.Boolean", out var boolType);
module.GetType("FlaxEngine.Networking.NetworkManagerMode", out var networkManagerModeType);
module.GetType("FlaxEngine.Networking.NetworkManager", out var networkManagerType);
var networkManagerGetMode = networkManagerType.Resolve().GetMethod("get_Mode", 0);
il.Body.InitLocals = true;
var varsStart = il.Body.Variables.Count;
// Is Server/Is Client boolean constants
il.Body.Variables.Add(new VariableDefinition(module.ImportReference(boolType))); // [0]
il.Body.Variables.Add(new VariableDefinition(module.ImportReference(boolType))); // [1]
il.InsertBefore(ilStart, il.Create(OpCodes.Ldc_I4, methodRPC.IsServer ? 1 : 0));
il.InsertBefore(ilStart, il.Create(OpCodes.Stloc, varsStart + 0)); // isServer loc=0
il.InsertBefore(ilStart, il.Create(OpCodes.Ldc_I4, methodRPC.IsClient ? 1 : 0));
il.InsertBefore(ilStart, il.Create(OpCodes.Stloc, varsStart + 1)); // isClient loc=1
// NetworkManagerMode mode = NetworkManager.Mode;
il.Body.Variables.Add(new VariableDefinition(module.ImportReference(networkManagerModeType))); // [2]
il.InsertBefore(ilStart, il.Create(OpCodes.Call, module.ImportReference(networkManagerGetMode)));
il.InsertBefore(ilStart, il.Create(OpCodes.Stloc, varsStart + 2)); // mode loc=2
// if ((server && networkMode == NetworkManagerMode.Client) || (client && networkMode != NetworkManagerMode.Client))
var jumpIfBodyStart = il.Create(OpCodes.Nop); // if block body
var jumpIf2Start = il.Create(OpCodes.Nop); // 2nd part of the if
var jumpBodyStart = il.Create(OpCodes.Nop); // original method body start
var jumpBodyEnd = il.Body.Instructions.First(x => x.OpCode == OpCodes.Ret);
il.InsertBefore(ilStart, il.Create(OpCodes.Ldloc, varsStart + 0));
il.InsertBefore(ilStart, il.Create(OpCodes.Brfalse_S, jumpIf2Start));
il.InsertBefore(ilStart, il.Create(OpCodes.Ldloc, varsStart + 2));
il.InsertBefore(ilStart, il.Create(OpCodes.Ldc_I4_2));
il.InsertBefore(ilStart, il.Create(OpCodes.Beq_S, jumpIfBodyStart));
// ||
il.InsertBefore(ilStart, jumpIf2Start);
il.InsertBefore(ilStart, il.Create(OpCodes.Ldloc, varsStart + 1));
il.InsertBefore(ilStart, il.Create(OpCodes.Brfalse_S, jumpBodyStart));
il.InsertBefore(ilStart, il.Create(OpCodes.Ldloc, varsStart + 2));
il.InsertBefore(ilStart, il.Create(OpCodes.Ldc_I4_2));
il.InsertBefore(ilStart, il.Create(OpCodes.Beq_S, jumpBodyStart));
// {
il.InsertBefore(ilStart, jumpIfBodyStart);
// NetworkStream stream = NetworkReplicator.BeginInvokeRPC();
il.Body.Variables.Add(new VariableDefinition(module.ImportReference(networkStream))); // [3]
var streamLocalIndex = varsStart + 3;
module.GetType("FlaxEngine.Networking.NetworkReplicator", out var networkReplicatorType);
var beginInvokeRPC = networkReplicatorType.Resolve().GetMethod("BeginInvokeRPC", 0);
il.InsertBefore(ilStart, il.Create(OpCodes.Call, module.ImportReference(beginInvokeRPC)));
il.InsertBefore(ilStart, il.Create(OpCodes.Stloc, streamLocalIndex)); // stream loc=3
// Serialize all RPC parameters
for (int i = 0; i < method.Parameters.Count; i++)
{
var parameter = method.Parameters[i];
var parameterType = parameter.ParameterType;
GenerateDotNetRPCSerializerType(type, true, ref failed, i + 1, parameterType, il, networkStream.Resolve(), streamLocalIndex, ilStart);
}
// NetworkReplicator.EndInvokeRPC(this, typeof(<type>), "<name>", stream);
il.InsertBefore(ilStart, il.Create(OpCodes.Nop));
il.InsertBefore(ilStart, il.Create(OpCodes.Ldarg_0));
il.InsertBefore(ilStart, il.Create(OpCodes.Ldtoken, type));
module.GetType("System.Type", out var typeType);
var getTypeFromHandle = typeType.Resolve().GetMethod("GetTypeFromHandle");
il.InsertBefore(ilStart, il.Create(OpCodes.Call, module.ImportReference(getTypeFromHandle)));
il.InsertBefore(ilStart, il.Create(OpCodes.Ldstr, method.Name));
il.InsertBefore(ilStart, il.Create(OpCodes.Ldloc, streamLocalIndex));
var endInvokeRPC = networkReplicatorType.Resolve().GetMethod("EndInvokeRPC", 4);
il.InsertBefore(ilStart, il.Create(OpCodes.Call, module.ImportReference(endInvokeRPC)));
// if (server && networkMode == NetworkManagerMode.Client) return;
if (methodRPC.IsServer)
{
il.InsertBefore(ilStart, il.Create(OpCodes.Nop));
il.InsertBefore(ilStart, il.Create(OpCodes.Ldloc, varsStart + 2));
il.InsertBefore(ilStart, il.Create(OpCodes.Ldc_I4_2));
il.InsertBefore(ilStart, il.Create(OpCodes.Beq_S, jumpBodyEnd));
}
// if (client && networkMode == NetworkManagerMode.Server) return;
if (methodRPC.IsClient)
{
il.InsertBefore(ilStart, il.Create(OpCodes.Nop));
il.InsertBefore(ilStart, il.Create(OpCodes.Ldloc, varsStart + 2));
il.InsertBefore(ilStart, il.Create(OpCodes.Ldc_I4_1));
il.InsertBefore(ilStart, il.Create(OpCodes.Beq_S, jumpBodyEnd));
}
// Continue to original method body
il.InsertBefore(ilStart, jumpBodyStart);
}
methodRPCs.Add(methodRPC);
}
}
}