From b513d0deef521318ad943d820dd37029075a33c4 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Mon, 11 Jul 2022 23:18:06 -0400 Subject: Added support for differentiating calls to basic functions, as well as arithmetic on the float3 type (#2313) * Added support for differentiating calls to basic functions, as well as arithmetic on the float3 type * Added test expected result --- source/slang/slang-check-expr.cpp | 37 ++++++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 11 deletions(-) (limited to 'source/slang/slang-check-expr.cpp') diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index df58b11ed..a3fec4802 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1509,6 +1509,29 @@ namespace Slang return expr; } + Type* primalToJVPParamType(ASTBuilder* builder, Type* primalType) + { + // Only float and float3 types can be differentiated for now. + + if(primalType->equals(builder->getFloatType())) + return primalType; + else if(auto primalVectorType = as(primalType)) + { + // TODO(sai): There's probably a more elegant way to check if a type is a float3? + if (getIntVal(primalVectorType->elementCount) == 3 && primalVectorType->elementType->equals(builder->getFloatType())) + return primalVectorType; + } + return nullptr; + } + + Type* primalToJVPReturnType(ASTBuilder* builder, Type* primalType) + { + if(auto jvpType = primalToJVPParamType(builder, primalType)) + return jvpType; + else + return builder->getVoidType(); + } + Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr) { // Check/Resolve inner function declaration. @@ -1524,18 +1547,10 @@ namespace Slang FuncType* jvpType = astBuilder->create(); - // Only float types can be differentiated for now. - // The JVP return type is float if primal return type is float // void otherwise. // - if (primalType->resultType->equals(astBuilder->getFloatType())) - jvpType->resultType = astBuilder->getFloatType(); - else - { - //TODO(yong): issue proper diagnostic here. - jvpType->resultType = astBuilder->getVoidType(); - } + jvpType->resultType = primalToJVPReturnType(astBuilder, primalType->getResultType()); // No support for differentiating function that throw errors, for now. SLANG_ASSERT(primalType->errorType->equals(astBuilder->getBottomType())); @@ -1548,8 +1563,8 @@ namespace Slang for (UInt i = 0; i < primalType->getParamCount(); i++) { - if(primalType->getParamType(i)->equals(astBuilder->getFloatType())) - jvpType->paramTypes.add(astBuilder->getFloatType()); + if(auto jvpParamType = primalToJVPParamType(astBuilder, primalType->getParamType(i))) + jvpType->paramTypes.add(jvpParamType); } expr->type = jvpType; -- cgit v1.2.3