diff options
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 7 |
1 files changed, 3 insertions, 4 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 67e8bf650..1895da70b 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1525,15 +1525,14 @@ namespace Slang Type* primalToJVPParamType(ASTBuilder* builder, Type* primalType) { - // Only float and float3 types can be differentiated for now. + // Only float and vector<float> types can be differentiated for now. if (primalType->equals(builder->getFloatType())) return primalType; else if (auto primalVectorType = as<VectorExpressionType>(primalType)) { - // TODO(sai): There's probably a more elegant way to check if a type is a float3? - if (getIntVal(primalVectorType->elementCount) == 3 && primalVectorType->elementType->equals(builder->getFloatType())) - return primalVectorType; + if (auto jvpElementType = primalToJVPParamType(builder, primalVectorType->elementType)) + return builder->getVectorType(jvpElementType, primalVectorType->elementCount); } else if (auto primalOutType = as<OutType>(primalType)) { |
