summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-07-11 23:18:06 -0400
committerGitHub <noreply@github.com>2022-07-11 23:18:06 -0400
commitb513d0deef521318ad943d820dd37029075a33c4 (patch)
treecc6dc625ae381e0461724c5b137e1a034b03e636 /source/slang/slang-check-expr.cpp
parent9261c7a23ddf061fe9f5bfc3376f09f3c0513bff (diff)
Added support for differentiating calls to basic functions, as well as arithmetic on the float3 type (#2313)
* Added support for differentiating calls to basic functions, as well as arithmetic on the float3 type * Added test expected result
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp37
1 files changed, 26 insertions, 11 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index df58b11ed..a3fec4802 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1509,6 +1509,29 @@ namespace Slang
return expr;
}
+ Type* primalToJVPParamType(ASTBuilder* builder, Type* primalType)
+ {
+ // Only float and float3 types can be differentiated for now.
+
+ if(primalType->equals(builder->getFloatType()))
+ return primalType;
+ else if(auto primalVectorType = as<VectorExpressionType>(primalType))
+ {
+ // TODO(sai): There's probably a more elegant way to check if a type is a float3?
+ if (getIntVal(primalVectorType->elementCount) == 3 && primalVectorType->elementType->equals(builder->getFloatType()))
+ return primalVectorType;
+ }
+ return nullptr;
+ }
+
+ Type* primalToJVPReturnType(ASTBuilder* builder, Type* primalType)
+ {
+ if(auto jvpType = primalToJVPParamType(builder, primalType))
+ return jvpType;
+ else
+ return builder->getVoidType();
+ }
+
Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr)
{
// Check/Resolve inner function declaration.
@@ -1524,18 +1547,10 @@ namespace Slang
FuncType* jvpType = astBuilder->create<FuncType>();
- // Only float types can be differentiated for now.
-
// The JVP return type is float if primal return type is float
// void otherwise.
//
- if (primalType->resultType->equals(astBuilder->getFloatType()))
- jvpType->resultType = astBuilder->getFloatType();
- else
- {
- //TODO(yong): issue proper diagnostic here.
- jvpType->resultType = astBuilder->getVoidType();
- }
+ jvpType->resultType = primalToJVPReturnType(astBuilder, primalType->getResultType());
// No support for differentiating function that throw errors, for now.
SLANG_ASSERT(primalType->errorType->equals(astBuilder->getBottomType()));
@@ -1548,8 +1563,8 @@ namespace Slang
for (UInt i = 0; i < primalType->getParamCount(); i++)
{
- if(primalType->getParamType(i)->equals(astBuilder->getFloatType()))
- jvpType->paramTypes.add(astBuilder->getFloatType());
+ if(auto jvpParamType = primalToJVPParamType(astBuilder, primalType->getParamType(i)))
+ jvpType->paramTypes.add(jvpParamType);
}
expr->type = jvpType;