diff --git a/Source/Engine/Visject/ShaderGraph.cpp b/Source/Engine/Visject/ShaderGraph.cpp index 688e46382..faf344263 100644 --- a/Source/Engine/Visject/ShaderGraph.cpp +++ b/Source/Engine/Visject/ShaderGraph.cpp @@ -146,6 +146,8 @@ void ShaderGenerator::ProcessGroupMath(Box* box, Node* node, Value& value) Box* b2 = node->GetBox(1); Value v1 = tryGetValue(b1, 0, Value::Zero); Value v2 = tryGetValue(b2, 1, Value::Zero); + if (SanitizeMathValue(v1, node, b1, &value)) + break; if (b1->HasConnection()) v2 = v2.Cast(v1.Type); else @@ -251,7 +253,10 @@ void ShaderGenerator::ProcessGroupMath(Box* box, Node* node, Value& value) // Lerp case 25: { - Value a = tryGetValue(node->GetBox(0), 0, Value::Zero); + auto boxA = node->GetBox(0); + Value a = tryGetValue(boxA, 0, Value::Zero); + if (SanitizeMathValue(a, node, boxA, &value)) + break; Value b = tryGetValue(node->GetBox(1), 1, Value::One).Cast(a.Type); Value alpha = tryGetValue(node->GetBox(2), 2, Value::Zero).Cast(ValueType::Float); String text = String::Format(TEXT("lerp({0}, {1}, {2})"), a.Value, b.Value, alpha.Value); @@ -1364,6 +1369,20 @@ SerializedMaterialParam& ShaderGenerator::findOrAddGlobalSDF() return param; } +bool ShaderGenerator::SanitizeMathValue(Value& value, Node* node, Box* box, Value* resultOnInvalid) +{ + bool invalid = value.Type == VariantType::Object; + if (invalid) + { + OnError(node, box, TEXT("Invalid input type for math operation")); + if (resultOnInvalid) + *resultOnInvalid = Value::Zero; + else + value = Value::Zero; + } + return invalid; +} + String ShaderGenerator::getLocalName(int32 index) { return TEXT("local") + StringUtils::ToString(index); diff --git a/Source/Engine/Visject/ShaderGraph.h b/Source/Engine/Visject/ShaderGraph.h index ab0e7d405..5b17604d0 100644 --- a/Source/Engine/Visject/ShaderGraph.h +++ b/Source/Engine/Visject/ShaderGraph.h @@ -255,6 +255,8 @@ protected: SerializedMaterialParam& findOrAddTextureGroupSampler(int32 index); SerializedMaterialParam& findOrAddGlobalSDF(); + bool SanitizeMathValue(Value& value, Node* node, Box* box, Value* resultOnInvalid = nullptr); + static String getLocalName(int32 index); static String getParamName(int32 index); };