Add network RPCs

This commit is contained in:
Wojciech Figat
2022-11-16 14:25:12 +01:00
parent 91ff0f76f8
commit efb48697fa
9 changed files with 560 additions and 37 deletions

View File

@@ -36,8 +36,10 @@ namespace Flax.Build.Plugins
public MethodDefinition Serialize;
public MethodDefinition Deserialize;
}
internal const string Network = "Network";
internal const string NetworkReplicated = "NetworkReplicated";
internal const string NetworkRpc = "NetworkRpc";
private const string Thunk1 = "INetworkSerializable_Serialize";
private const string Thunk2 = "INetworkSerializable_Deserialize";
private static readonly Dictionary<string, InBuildSerializer> _inBuildSerializers = new Dictionary<string, InBuildSerializer>()
@@ -78,34 +80,42 @@ namespace Flax.Build.Plugins
private void OnParseMemberTag(ref bool valid, BindingsGenerator.TagParameter tag, MemberInfo memberInfo)
{
if (tag.Tag != NetworkReplicated)
return;
// Mark member as replicated
valid = true;
memberInfo.SetTag(NetworkReplicated, string.Empty);
if (tag.Tag == NetworkReplicated)
{
// Mark member as replicated
valid = true;
memberInfo.SetTag(NetworkReplicated, string.Empty);
}
else if (tag.Tag == NetworkRpc)
{
// Mark member as rpc
valid = true;
memberInfo.SetTag(NetworkRpc, tag.Value);
}
}
private void OnGenerateCppTypeInternals(Builder.BuildData buildData, ApiTypeInfo typeInfo, StringBuilder contents)
{
// Skip modules that don't use networking
var module = BindingsGenerator.CurrentModule;
if (module.GetTag(NetworkReplicated) == null)
if (module.GetTag(Network) == null)
return;
// Check if type uses automated network replication
// Check if type uses automated network replication/RPCs
List<FieldInfo> fields = null;
List<PropertyInfo> properties = null;
List<FunctionInfo> functions = null;
if (typeInfo is ClassInfo classInfo)
{
fields = classInfo.Fields;
properties = classInfo.Properties;
functions = classInfo.Functions;
}
else if (typeInfo is StructureInfo structInfo)
{
fields = structInfo.Fields;
}
bool useReplication = false;
bool useReplication = false, useRpc = false;
if (fields != null)
{
foreach (var fieldInfo in fields)
@@ -128,16 +138,117 @@ namespace Flax.Build.Plugins
}
}
}
if (!useReplication)
return;
typeInfo.SetTag(NetworkReplicated, string.Empty);
if (functions != null)
{
foreach (var functionInfo in functions)
{
if (functionInfo.GetTag(NetworkRpc) != null)
{
useRpc = true;
break;
}
}
}
if (useReplication)
{
typeInfo.SetTag(NetworkReplicated, string.Empty);
// Generate C++ wrapper functions to serialize/deserialize type
BindingsGenerator.CppIncludeFiles.Add("Engine/Networking/NetworkReplicator.h");
BindingsGenerator.CppIncludeFiles.Add("Engine/Networking/NetworkStream.h");
OnGenerateCppTypeSerialize(buildData, typeInfo, contents, fields, properties, true);
OnGenerateCppTypeSerialize(buildData, typeInfo, contents, fields, properties, false);
// Generate C++ wrapper functions to serialize/deserialize type
BindingsGenerator.CppIncludeFiles.Add("Engine/Networking/NetworkReplicator.h");
BindingsGenerator.CppIncludeFiles.Add("Engine/Networking/NetworkStream.h");
OnGenerateCppTypeSerialize(buildData, typeInfo, contents, fields, properties, true);
OnGenerateCppTypeSerialize(buildData, typeInfo, contents, fields, properties, false);
}
if (useRpc)
{
typeInfo.SetTag(NetworkRpc, string.Empty);
// Generate C++ wrapper functions to invoke/execute RPC
BindingsGenerator.CppIncludeFiles.Add("Engine/Networking/NetworkStream.h");
BindingsGenerator.CppIncludeFiles.Add("Engine/Networking/NetworkReplicator.h");
BindingsGenerator.CppIncludeFiles.Add("Engine/Networking/NetworkChannelType.h");
BindingsGenerator.CppIncludeFiles.Add("Engine/Networking/NetworkRpc.h");
foreach (var functionInfo in functions)
{
var tag = functionInfo.GetTag(NetworkRpc);
if (tag == null)
continue;
if (functionInfo.UniqueName != functionInfo.Name)
throw new Exception($"Invalid network RPC method {functionInfo.Name} name in type {typeInfo.Name}. Network RPC functions names have to be unique.");
bool isServer = tag.IndexOf("Server", StringComparison.OrdinalIgnoreCase) != -1;
bool isClient = tag.IndexOf("Client", StringComparison.OrdinalIgnoreCase) != -1;
if (isServer && isClient)
throw new Exception($"Network RPC {functionInfo.Name} in {typeInfo.Name} cannot be both Server and Client.");
if (!isServer && !isClient)
throw new Exception($"Network RPC {functionInfo.Name} in {typeInfo.Name} needs to have Server or Client specifier.");
var channelType = "ReliableOrdered";
if (tag.IndexOf("UnreliableOrdered", StringComparison.OrdinalIgnoreCase) != -1)
channelType = "UnreliableOrdered";
else if (tag.IndexOf("ReliableOrdered", StringComparison.OrdinalIgnoreCase) != -1)
channelType = "ReliableOrdered";
else if (tag.IndexOf("Unreliable", StringComparison.OrdinalIgnoreCase) != -1)
channelType = "Unreliable";
else if (tag.IndexOf("Reliable", StringComparison.OrdinalIgnoreCase) != -1)
channelType = "Reliable";
// Generated method thunk to execute RPC from network
{
contents.Append(" static void ").Append(functionInfo.Name).AppendLine("_Execute(ScriptingObject* obj, NetworkStream* stream)");
contents.AppendLine(" {");
string argNames = string.Empty;
for (int i = 0; i < functionInfo.Parameters.Count; i++)
{
var arg = functionInfo.Parameters[i];
if (i != 0)
argNames += ", ";
argNames += arg.Name;
// Deserialize arguments
contents.AppendLine($" {arg.Type.Type} {arg.Name};");
contents.AppendLine($" stream->Read({arg.Name});");
}
// Call method locally
contents.AppendLine($" ASSERT(obj && obj->Is<{typeInfo.NativeName}>());");
contents.AppendLine($" (({typeInfo.NativeName}*)obj)->{functionInfo.Name}({argNames});");
contents.AppendLine(" }");
}
contents.AppendLine();
// Generated method thunk to invoke RPC to network
{
contents.Append(" static void ").Append(functionInfo.Name).AppendLine("_Invoke(ScriptingObject* obj, void** args)");
contents.AppendLine(" {");
contents.AppendLine(" NetworkStream* stream = NetworkReplicator::BeginInvokeRPC();");
for (int i = 0; i < functionInfo.Parameters.Count; i++)
{
var arg = functionInfo.Parameters[i];
// Serialize arguments
contents.AppendLine($" stream->Write(*({arg.Type.Type}*)args[{i}]);");
}
// Invoke RPC
contents.AppendLine($" NetworkReplicator::EndInvokeRPC(obj, {typeInfo.NativeName}::TypeInitializer, StringAnsiView(\"{functionInfo.Name}\", {functionInfo.Name.Length}), stream);");
contents.AppendLine(" }");
}
contents.AppendLine();
// Generated info about RPC implementation
{
contents.Append(" static NetworkRpcInfo ").Append(functionInfo.Name).AppendLine("_Info()");
contents.AppendLine(" {");
contents.AppendLine(" NetworkRpcInfo info;");
contents.AppendLine($" info.Server = {(isServer ? "1" : "0")};");
contents.AppendLine($" info.Execute = {functionInfo.Name}_Execute;");
contents.AppendLine($" info.Invoke = {functionInfo.Name}_Invoke;");
contents.AppendLine($" info.Channel = (uint8)NetworkChannelType::{channelType};");
contents.AppendLine(" return info;");
contents.AppendLine(" }");
}
contents.AppendLine();
}
}
}
private void OnGenerateCppTypeSerialize(Builder.BuildData buildData, ApiTypeInfo typeInfo, StringBuilder contents, List<FieldInfo> fields, List<PropertyInfo> properties, bool serialize)
@@ -265,13 +376,37 @@ namespace Flax.Build.Plugins
private void OnGenerateCppTypeInitRuntime(Builder.BuildData buildData, ApiTypeInfo typeInfo, StringBuilder contents)
{
if (typeInfo.GetTag(NetworkReplicated) == null)
// Skip types that don't use networking
var replicatedTag = typeInfo.GetTag(NetworkReplicated);
var rpcTag = typeInfo.GetTag(NetworkRpc);
if (replicatedTag == null && rpcTag == null)
return;
var typeNameNative = typeInfo.FullNameNative;
var typeNameInternal = typeInfo.FullNameNativeInternal;
// Register generated serializer functions
contents.AppendLine($" NetworkReplicator::AddSerializer(ScriptingTypeHandle({typeNameNative}::TypeInitializer), {typeNameInternal}Internal::INetworkSerializable_Serialize, {typeNameInternal}Internal::INetworkSerializable_Deserialize);");
if (replicatedTag != null)
{
// Register generated serializer functions
contents.AppendLine($" NetworkReplicator::AddSerializer(ScriptingTypeHandle({typeNameNative}::TypeInitializer), {typeNameInternal}Internal::INetworkSerializable_Serialize, {typeNameInternal}Internal::INetworkSerializable_Deserialize);");
}
if (rpcTag != null)
{
// Register generated RPCs
List<FunctionInfo> functions = null;
if (typeInfo is ClassInfo classInfo)
{
functions = classInfo.Functions;
}
if (functions != null)
{
foreach (var functionInfo in functions)
{
if (functionInfo.GetTag(NetworkRpc) == null)
continue;
contents.AppendLine($" NetworkRpcInfo::RPCsTable[NetworkRpcName({typeNameNative}::TypeInitializer, StringAnsiView(\"{functionInfo.Name}\", {functionInfo.Name.Length}))] = {functionInfo.Name}_Info();");
}
}
}
}
private void OnGenerateCSharpTypeInternals(Builder.BuildData buildData, ApiTypeInfo typeInfo, StringBuilder contents, string indent)
@@ -318,7 +453,7 @@ namespace Flax.Build.Plugins
private void OnBuildDotNetAssembly(TaskGraph graph, Builder.BuildData buildData, NativeCpp.BuildOptions buildOptions, Task buildTask, IGrouping<string, Module> binaryModule)
{
// Skip assemblies not using netowrking
if (!binaryModule.Any(module => module.Tags.ContainsKey(NetworkReplicated)))
if (!binaryModule.Any(module => module.Tags.ContainsKey(Network)))
return;
// Generate netoworking code inside assembly after it's being compiled
@@ -326,7 +461,7 @@ namespace Flax.Build.Plugins
var task = graph.Add<Task>();
task.ProducedFiles.Add(assemblyPath);
task.WorkingDirectory = buildTask.WorkingDirectory;
task.Command = () => OnPatchAssembly(buildData, buildOptions, buildTask, assemblyPath);
task.Command = () => OnPatchDotNetAssembly(buildData, buildOptions, buildTask, assemblyPath);
task.CommandPath = null;
task.InfoMessage = $"Generating netowrking code for {Path.GetFileName(assemblyPath)}...";
task.Cost = 50;
@@ -335,7 +470,7 @@ namespace Flax.Build.Plugins
task.DependentTasks.Add(buildTask);
}
private void OnPatchAssembly(Builder.BuildData buildData, NativeCpp.BuildOptions buildOptions, Task buildTask, string assemblyPath)
private void OnPatchDotNetAssembly(Builder.BuildData buildData, NativeCpp.BuildOptions buildOptions, Task buildTask, string assemblyPath)
{
using (DefaultAssemblyResolver assemblyResolver = new DefaultAssemblyResolver())
using (AssemblyDefinition assembly = AssemblyDefinition.ReadAssembly(assemblyPath, new ReaderParameters{ ReadWrite = true, ReadSymbols = true, AssemblyResolver = assemblyResolver }))
@@ -391,8 +526,6 @@ namespace Flax.Build.Plugins
isNetworkReplicated = true;
break;
}
if (type.IsValueType)
{
if (isINetworkSerializable)

View File

@@ -41,7 +41,7 @@ namespace Flax.Build
public static MethodDefinition GetMethod(this TypeDefinition type, string name, int argCount)
{
var result = type.Methods.First(x => x.Name == name && x.Parameters.Count == argCount);
var result = type.Methods.FirstOrDefault(x => x.Name == name && x.Parameters.Count == argCount);
if (result == null)
throw new Exception($"Failed to find method '{name}' (args={argCount}) in '{type.FullName}'.");
return result;
@@ -49,12 +49,32 @@ namespace Flax.Build
public static FieldDefinition GetField(this TypeDefinition type, string name)
{
var result = type.Fields.First(x => x.Name == name);
var result = type.Fields.FirstOrDefault(x => x.Name == name);
if (result == null)
throw new Exception($"Failed to find field '{name}' in '{type.FullName}'.");
return result;
}
public static CustomAttributeNamedArgument GetField(this CustomAttribute attribute, string name)
{
foreach (var f in attribute.Fields)
{
if (f.Name == name)
return f;
}
throw new Exception($"Failed to find field '{name}' in '{attribute.AttributeType.FullName}'.");
}
public static object GetFieldValue(this CustomAttribute attribute, string name, object defaultValue)
{
foreach (var f in attribute.Fields)
{
if (f.Name == name)
return f.Argument.Value;
}
return defaultValue;
}
public static bool IsScriptingObject(this TypeDefinition type)
{
if (type.FullName == "FlaxEngine.Object")