summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp7
1 files changed, 3 insertions, 4 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 67e8bf650..1895da70b 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1525,15 +1525,14 @@ namespace Slang
Type* primalToJVPParamType(ASTBuilder* builder, Type* primalType)
{
- // Only float and float3 types can be differentiated for now.
+ // Only float and vector<float> types can be differentiated for now.
if (primalType->equals(builder->getFloatType()))
return primalType;
else if (auto primalVectorType = as<VectorExpressionType>(primalType))
{
- // TODO(sai): There's probably a more elegant way to check if a type is a float3?
- if (getIntVal(primalVectorType->elementCount) == 3 && primalVectorType->elementType->equals(builder->getFloatType()))
- return primalVectorType;
+ if (auto jvpElementType = primalToJVPParamType(builder, primalVectorType->elementType))
+ return builder->getVectorType(jvpElementType, primalVectorType->elementCount);
}
else if (auto primalOutType = as<OutType>(primalType))
{