From 7950c0d9d76dac13c99e2fa929063bd6290240cb Mon Sep 17 00:00:00 2001 From: Wojtek Figat Date: Tue, 14 Mar 2023 00:02:24 +0100 Subject: [PATCH] Fix codegen for C# networking when using custom structures for replication and RPCs --- .../Build/Plugins/NetworkingPlugin.cs | 310 +++++++++++------- 1 file changed, 183 insertions(+), 127 deletions(-) diff --git a/Source/Tools/Flax.Build/Build/Plugins/NetworkingPlugin.cs b/Source/Tools/Flax.Build/Build/Plugins/NetworkingPlugin.cs index 99ded7da1..40f86ef66 100644 --- a/Source/Tools/Flax.Build/Build/Plugins/NetworkingPlugin.cs +++ b/Source/Tools/Flax.Build/Build/Plugins/NetworkingPlugin.cs @@ -47,6 +47,18 @@ namespace Flax.Build.Plugins public MethodDefinition Execute; } + private struct DotnetContext + { + public bool Modified; + public bool Failed; + public AssemblyDefinition Assembly; + public List AddSerializers; + public List MethodRPCs; + public HashSet GeneratedSerializers; + public TypeReference VoidType; + public TypeReference NetworkStreamType; + } + internal const string Network = "Network"; internal const string NetworkReplicated = "NetworkReplicated"; internal const string NetworkReplicatedAttribute = "FlaxEngine.NetworkReplicatedAttribute"; @@ -494,7 +506,7 @@ namespace Flax.Build.Plugins task.WorkingDirectory = buildTask.WorkingDirectory; task.Command = () => OnPatchDotNetAssembly(buildData, buildOptions, buildTask, assemblyPath); task.CommandPath = null; - task.InfoMessage = $"Generating netowrking code for {Path.GetFileName(assemblyPath)}..."; + task.InfoMessage = $"Generating networking code for {Path.GetFileName(assemblyPath)}..."; task.Cost = 50; task.DisableCache = true; task.DependentTasks = new HashSet(); @@ -519,109 +531,31 @@ namespace Flax.Build.Plugins assemblyResolver.AddSearchDirectory(e); ModuleDefinition module = assembly.MainModule; - TypeReference voidType = module.ImportReference(typeof(void)); - module.GetType("FlaxEngine.Networking.NetworkStream", out var networkStreamType); // Process all types within a module - bool modified = false; - bool failed = false; - var addSerializers = new List(); - var methodRPCs = new List(); + var context = new DotnetContext + { + Modified = false, + Failed = false, + Assembly = assembly, + AddSerializers = new List(), + MethodRPCs = new List(), + GeneratedSerializers = new HashSet(), + VoidType = module.ImportReference(typeof(void)), + }; + module.GetType("FlaxEngine.Networking.NetworkStream", out context.NetworkStreamType); foreach (TypeDefinition type in module.Types) { - if (type.IsInterface || type.IsEnum) - continue; - - // Generate RPCs - var methods = type.Methods; - var methodsCount = methods.Count; // methods list can be modified during RPCs generation - for (int i = 0; i < methodsCount; i++) - { - MethodDefinition method = methods[i]; - var attribute = method.CustomAttributes.FirstOrDefault(x => x.AttributeType.FullName == "FlaxEngine.NetworkRpcAttribute"); - if (attribute != null) - { - GenerateDotNetRPCBody(type, method, attribute, ref failed, networkStreamType, methodRPCs); - modified = true; - } - } - - var isNative = type.HasAttribute("FlaxEngine.UnmanagedAttribute"); - if (isNative) - continue; - - // Generate serializers - if (type.HasMethod(Thunk1) || type.HasMethod(Thunk2)) - continue; - var isINetworkSerializable = type.HasInterface("FlaxEngine.Networking.INetworkSerializable"); - MethodDefinition serializeINetworkSerializable = null, deserializeINetworkSerializable = null; - if (isINetworkSerializable) - { - foreach (MethodDefinition m in type.Methods) - { - if (m.HasBody && m.Parameters.Count == 1 && m.Parameters[0].ParameterType.FullName == "FlaxEngine.Networking.NetworkStream") - { - if (m.Name == "Serialize") - serializeINetworkSerializable = m; - else if (m.Name == "Deserialize") - deserializeINetworkSerializable = m; - } - } - } - - var isNetworkReplicated = false; - foreach (FieldDefinition f in type.Fields) - { - if (!f.HasAttribute(NetworkReplicatedAttribute)) - continue; - isNetworkReplicated = true; - break; - } - - foreach (PropertyDefinition p in type.Properties) - { - if (!p.HasAttribute(NetworkReplicatedAttribute)) - continue; - isNetworkReplicated = true; - break; - } - - if (type.IsValueType) - { - if (isINetworkSerializable) - { - // Generate INetworkSerializable interface method calls - GenerateCallINetworkSerializable(type, Thunk1, voidType, networkStreamType, serializeINetworkSerializable); - GenerateCallINetworkSerializable(type, Thunk2, voidType, networkStreamType, deserializeINetworkSerializable); - modified = true; - } - else if (isNetworkReplicated) - { - // Generate serializization methods - GenerateSerializer(type, true, ref failed, Thunk1, voidType, networkStreamType); - GenerateSerializer(type, false, ref failed, Thunk2, voidType, networkStreamType); - modified = true; - } - } - else if (!isINetworkSerializable && isNetworkReplicated) - { - // Generate serializization methods - var addSerializer = new TypeSerializer(); - addSerializer.Type = type; - addSerializer.Serialize = GenerateNativeSerializer(type, true, ref failed, Thunk1, voidType, networkStreamType); - addSerializer.Deserialize = GenerateNativeSerializer(type, false, ref failed, Thunk2, voidType, networkStreamType); - addSerializers.Add(addSerializer); - modified = true; - } + GenerateTypeNetworking(ref context, type); } - if (failed) + if (context.Failed) throw new Exception($"Failed to generate network replication for assembly {assemblyPath}"); - if (!modified) + if (!context.Modified) return; // Generate serializers initializer (invoked on module load) - if (addSerializers.Count != 0 || methodRPCs.Count != 0) + if (context.AddSerializers.Count != 0 || context.MethodRPCs.Count != 0) { // Create class var name = "Initializer"; @@ -641,7 +575,7 @@ namespace Flax.Build.Plugins c.CustomAttributes.Add(attribute); // Add Init method - var m = new MethodDefinition("Init", MethodAttributes.Private | MethodAttributes.Static | MethodAttributes.HideBySig, voidType); + var m = new MethodDefinition("Init", MethodAttributes.Private | MethodAttributes.Static | MethodAttributes.HideBySig, context.VoidType); ILProcessor il = m.Body.GetILProcessor(); il.Emit(OpCodes.Nop); module.GetType("System.Type", out var typeType); @@ -655,7 +589,7 @@ namespace Flax.Build.Plugins module.ImportReference(addRPC); var executeRPCFuncType = addRPC.Parameters[2].ParameterType; var executeRPCFuncCtor = executeRPCFuncType.Resolve().GetMethod(".ctor"); - foreach (var e in addSerializers) + foreach (var e in context.AddSerializers) { // NetworkReplicator.AddSerializer(typeof(), .INetworkSerializable_SerializeNative, .INetworkSerializable_DeserializeNative); il.Emit(OpCodes.Ldtoken, e.Type); @@ -669,7 +603,7 @@ namespace Flax.Build.Plugins il.Emit(OpCodes.Call, module.ImportReference(addSerializer)); } - foreach (var e in methodRPCs) + foreach (var e in context.MethodRPCs) { // NetworkReplicator.AddRPC(typeof(), "", _Execute, , , ); il.Emit(OpCodes.Ldtoken, e.Type); @@ -694,10 +628,124 @@ namespace Flax.Build.Plugins } } - private static void GenerateCallINetworkSerializable(TypeDefinition type, string name, TypeReference voidType, TypeReference networkStreamType, MethodDefinition method) + private static void GenerateTypeNetworking(ref DotnetContext context, TypeDefinition type) { - var m = new MethodDefinition(name, MethodAttributes.Public | MethodAttributes.HideBySig, voidType); - m.Parameters.Add(new ParameterDefinition("stream", ParameterAttributes.None, networkStreamType)); + if (type.IsInterface || type.IsEnum) + return; + + // Process nested types + foreach (var nestedType in type.NestedTypes) + { + GenerateTypeNetworking(ref context, nestedType); + } + + if (type.IsClass) + { + // Generate RPCs + var methods = type.Methods; + var methodsCount = methods.Count; // methods list can be modified during RPCs generation + for (int i = 0; i < methodsCount; i++) + { + MethodDefinition method = methods[i]; + var attribute = method.CustomAttributes.FirstOrDefault(x => x.AttributeType.FullName == "FlaxEngine.NetworkRpcAttribute"); + if (attribute != null) + { + GenerateDotNetRPCBody(ref context, type, method, attribute, context.NetworkStreamType); + context.Modified = true; + } + } + } + + GenerateTypeSerialization(ref context, type); + } + + private static void GenerateTypeSerialization(ref DotnetContext context, TypeDefinition type) + { + // Skip types outside from current assembly + if (context.Assembly.MainModule != type.Module) + return; + + // Skip if already generated serialization for this type (eg. via referenced RPC in other type) + if (context.GeneratedSerializers.Contains(type)) + return; + context.GeneratedSerializers.Add(type); + + // Skip native types + var isNative = type.HasAttribute("FlaxEngine.UnmanagedAttribute"); + if (isNative) + return; + + // Skip if manually implemented serializers + if (type.HasMethod(Thunk1) || type.HasMethod(Thunk2)) + return; + + // Generate serializers + var isINetworkSerializable = type.HasInterface("FlaxEngine.Networking.INetworkSerializable"); + MethodDefinition serializeINetworkSerializable = null, deserializeINetworkSerializable = null; + if (isINetworkSerializable) + { + foreach (MethodDefinition m in type.Methods) + { + if (m.HasBody && m.Parameters.Count == 1 && m.Parameters[0].ParameterType.FullName == "FlaxEngine.Networking.NetworkStream") + { + if (m.Name == "Serialize") + serializeINetworkSerializable = m; + else if (m.Name == "Deserialize") + deserializeINetworkSerializable = m; + } + } + } + + var isNetworkReplicated = false; + foreach (FieldDefinition f in type.Fields) + { + if (!f.HasAttribute(NetworkReplicatedAttribute)) + continue; + isNetworkReplicated = true; + break; + } + + foreach (PropertyDefinition p in type.Properties) + { + if (!p.HasAttribute(NetworkReplicatedAttribute)) + continue; + isNetworkReplicated = true; + break; + } + + if (type.IsValueType) + { + if (isINetworkSerializable) + { + // Generate INetworkSerializable interface method calls + GenerateCallINetworkSerializable(ref context, type, Thunk1, serializeINetworkSerializable); + GenerateCallINetworkSerializable(ref context, type, Thunk2, deserializeINetworkSerializable); + context.Modified = true; + } + else if (isNetworkReplicated) + { + // Generate serializization methods + GenerateSerializer(ref context, type, true, Thunk1); + GenerateSerializer(ref context, type, false, Thunk2); + context.Modified = true; + } + } + else if (!isINetworkSerializable && isNetworkReplicated) + { + // Generate serializization methods + var addSerializer = new TypeSerializer(); + addSerializer.Type = type; + addSerializer.Serialize = GenerateNativeSerializer(ref context, type, true, Thunk1); + addSerializer.Deserialize = GenerateNativeSerializer(ref context, type, false, Thunk2); + context.AddSerializers.Add(addSerializer); + context.Modified = true; + } + } + + private static void GenerateCallINetworkSerializable(ref DotnetContext context, TypeDefinition type, string name, MethodDefinition method) + { + var m = new MethodDefinition(name, MethodAttributes.Public | MethodAttributes.HideBySig, context.VoidType); + m.Parameters.Add(new ParameterDefinition("stream", ParameterAttributes.None, context.NetworkStreamType)); ILProcessor il = m.Body.GetILProcessor(); il.Emit(OpCodes.Nop); il.Emit(OpCodes.Ldarg_0); @@ -708,12 +756,11 @@ namespace Flax.Build.Plugins type.Methods.Add(m); } - private static MethodDefinition GenerateSerializer(TypeDefinition type, bool serialize, ref bool failed, string name, TypeReference voidType, TypeReference networkStreamType) + private static MethodDefinition GenerateSerializer(ref DotnetContext context, TypeDefinition type, bool serialize, string name) { ModuleDefinition module = type.Module; - var m = new MethodDefinition(name, MethodAttributes.Public | MethodAttributes.HideBySig, voidType); - m.Parameters.Add(new ParameterDefinition("stream", ParameterAttributes.None, module.ImportReference(networkStreamType))); - TypeDefinition networkStream = networkStreamType.Resolve(); + var m = new MethodDefinition(name, MethodAttributes.Public | MethodAttributes.HideBySig, context.VoidType); + m.Parameters.Add(new ParameterDefinition("stream", ParameterAttributes.None, module.ImportReference(context.NetworkStreamType))); ILProcessor il = m.Body.GetILProcessor(); il.Emit(OpCodes.Nop); @@ -728,7 +775,7 @@ namespace Flax.Build.Plugins { if (!f.HasAttribute(NetworkReplicatedAttribute)) continue; - GenerateSerializerType(type, serialize, ref failed, f, null, f.FieldType, il, networkStream); + GenerateSerializerType(ref context, type, serialize, f, null, f.FieldType, il); } // Serialize all type properties marked with NetworkReplicated attribute @@ -736,7 +783,7 @@ namespace Flax.Build.Plugins { if (!p.HasAttribute(NetworkReplicatedAttribute)) continue; - GenerateSerializerType(type, serialize, ref failed, null, p, p.PropertyType, il, networkStream); + GenerateSerializerType(ref context, type, serialize, null, p, p.PropertyType, il); } if (serialize) @@ -746,17 +793,17 @@ namespace Flax.Build.Plugins return m; } - private static MethodDefinition GenerateNativeSerializer(TypeDefinition type, bool serialize, ref bool failed, string name, TypeReference voidType, TypeReference networkStreamType) + private static MethodDefinition GenerateNativeSerializer(ref DotnetContext context, TypeDefinition type, bool serialize, string name) { ModuleDefinition module = type.Module; module.GetType("System.IntPtr", out var intPtrType); module.GetType("FlaxEngine.Object", out var scriptingObjectType); var fromUnmanagedPtr = scriptingObjectType.Resolve().GetMethod("FromUnmanagedPtr"); - var m = new MethodDefinition(name + "Native", MethodAttributes.Public | MethodAttributes.Static | MethodAttributes.HideBySig, voidType); + var m = new MethodDefinition(name + "Native", MethodAttributes.Public | MethodAttributes.Static | MethodAttributes.HideBySig, context.VoidType); m.Parameters.Add(new ParameterDefinition("instancePtr", ParameterAttributes.None, intPtrType)); m.Parameters.Add(new ParameterDefinition("streamPtr", ParameterAttributes.None, intPtrType)); - TypeReference networkStream = module.ImportReference(networkStreamType); + TypeReference networkStream = module.ImportReference(context.NetworkStreamType); ILProcessor il = m.Body.GetILProcessor(); il.Emit(OpCodes.Nop); il.Body.InitLocals = true; @@ -776,7 +823,7 @@ namespace Flax.Build.Plugins il.Emit(OpCodes.Stloc_1); // Generate normal serializer - var serializer = GenerateSerializer(type, serialize, ref failed, name, voidType, networkStreamType); + var serializer = GenerateSerializer(ref context, type, serialize, name); // Call serializer il.Emit(OpCodes.Ldloc_0); @@ -816,10 +863,11 @@ namespace Flax.Build.Plugins } } - private static void GenerateSerializerType(TypeDefinition type, bool serialize, ref bool failed, FieldReference field, PropertyDefinition property, TypeReference valueType, ILProcessor il, TypeDefinition networkStreamType) + private static void GenerateSerializerType(ref DotnetContext context, TypeDefinition type, bool serialize, FieldReference field, PropertyDefinition property, TypeReference valueType, ILProcessor il) { if (field == null && property == null) throw new ArgumentException(); + TypeDefinition networkStreamType = context.NetworkStreamType.Resolve(); var propertyGetOpCode = OpCodes.Call; var propertySetOpCode = OpCodes.Call; if (property != null) @@ -827,14 +875,14 @@ namespace Flax.Build.Plugins if (property.GetMethod == null) { MonoCecil.CompilationError($"Missing getter method for property '{property.Name}' of type {valueType.FullName} in {type.FullName} for automatic replication.", property); - failed = true; + context.Failed = true; return; } if (property.SetMethod == null) { MonoCecil.CompilationError($"Missing setter method for property '{property.Name}' of type {valueType.FullName} in {type.FullName} for automatic replication.", property); - failed = true; + context.Failed = true; return; } @@ -846,6 +894,10 @@ namespace Flax.Build.Plugins ModuleDefinition module = type.Module; TypeDefinition valueTypeDef = valueType.Resolve(); + + // Ensure to have valid serialization already generated for that value type (eg. when using custom structure field serialization) + GenerateTypeSerialization(ref context, valueTypeDef); + if (_inBuildSerializers.TryGetValue(valueType.FullName, out var serializer)) { // Call NetworkStream method to write/read data @@ -1121,14 +1173,18 @@ namespace Flax.Build.Plugins MonoCecil.CompilationError($"Not supported type '{valueType.FullName}' on {field.Name} in {type.FullName} for automatic replication.", field.Resolve()); else MonoCecil.CompilationError($"Not supported type '{valueType.FullName}' for automatic replication."); - failed = true; + context.Failed = true; } } - private static void GenerateDotNetRPCSerializerType(TypeDefinition type, bool serialize, ref bool failed, int localIndex, TypeReference valueType, ILProcessor il, TypeDefinition networkStreamType, int streamLocalIndex, Instruction ilStart) + private static void GenerateDotNetRPCSerializerType(ref DotnetContext context, TypeDefinition type, bool serialize, int localIndex, TypeReference valueType, ILProcessor il, TypeDefinition networkStreamType, int streamLocalIndex, Instruction ilStart) { ModuleDefinition module = type.Module; TypeDefinition valueTypeDef = valueType.Resolve(); + + // Ensure to have valid serialization already generated for that value type + GenerateTypeSerialization(ref context, valueTypeDef); + if (_inBuildSerializers.TryGetValue(valueType.FullName, out var serializer)) { // Call NetworkStream method to write/read data @@ -1233,24 +1289,24 @@ namespace Flax.Build.Plugins { // Unknown type Log.Error($"Not supported type '{valueType.FullName}' for RPC parameter in {type.FullName}."); - failed = true; + context.Failed = true; } } - private static void GenerateDotNetRPCBody(TypeDefinition type, MethodDefinition method, CustomAttribute attribute, ref bool failed, TypeReference networkStreamType, List methodRPCs) + private static void GenerateDotNetRPCBody(ref DotnetContext context, TypeDefinition type, MethodDefinition method, CustomAttribute attribute, TypeReference networkStreamType) { // Validate RPC usage if (method.IsAbstract) { MonoCecil.CompilationError($"Not supported abstract RPC method '{method.FullName}'.", method); - failed = true; + context.Failed = true; return; } if (method.IsVirtual) { MonoCecil.CompilationError($"Not supported virtual RPC method '{method.FullName}'.", method); - failed = true; + context.Failed = true; return; } @@ -1259,14 +1315,14 @@ namespace Flax.Build.Plugins if (method.ReturnType != voidType) { MonoCecil.CompilationError($"Not supported non-void RPC method '{method.FullName}'.", method); - failed = true; + context.Failed = true; return; } if (method.IsStatic) { MonoCecil.CompilationError($"Not supported static RPC method '{method.FullName}'.", method); - failed = true; + context.Failed = true; return; } @@ -1287,14 +1343,14 @@ namespace Flax.Build.Plugins if (methodRPC.IsServer && methodRPC.IsClient) { MonoCecil.CompilationError($"Network RPC {method.Name} in {type.FullName} cannot be both Server and Client.", method); - failed = true; + context.Failed = true; return; } if (!methodRPC.IsServer && !methodRPC.IsClient) { MonoCecil.CompilationError($"Network RPC {method.Name} in {type.FullName} needs to have Server or Client specifier.", method); - failed = true; + context.Failed = true; return; } @@ -1334,7 +1390,7 @@ namespace Flax.Build.Plugins if (parameter.IsOut) { MonoCecil.CompilationError($"Network RPC {method.Name} in {type.FullName} parameter {parameter.Name} cannot be 'out'.", method); - failed = true; + context.Failed = true; return; } @@ -1347,7 +1403,7 @@ namespace Flax.Build.Plugins { var parameter = method.Parameters[i]; var parameterType = parameter.ParameterType; - GenerateDotNetRPCSerializerType(type, false, ref failed, argsStart + i, parameterType, il, networkStream.Resolve(), 1, null); + GenerateDotNetRPCSerializerType(ref context, type, false, argsStart + i, parameterType, il, networkStream.Resolve(), 1, null); } // Call RPC method body @@ -1424,7 +1480,7 @@ namespace Flax.Build.Plugins { var parameter = method.Parameters[i]; var parameterType = parameter.ParameterType; - GenerateDotNetRPCSerializerType(type, true, ref failed, i + 1, parameterType, il, networkStream.Resolve(), streamLocalIndex, ilStart); + GenerateDotNetRPCSerializerType(ref context, type, true, i + 1, parameterType, il, networkStream.Resolve(), streamLocalIndex, ilStart); } // NetworkReplicator.EndInvokeRPC(this, typeof(), "", stream); @@ -1471,7 +1527,7 @@ namespace Flax.Build.Plugins il.InsertBefore(ilStart, jumpBodyStart); } - methodRPCs.Add(methodRPC); + context.MethodRPCs.Add(methodRPC); } } }