From 4af61e2296a49876c2d9e7cf192ae825302a83de Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 13 Jul 2022 15:55:30 -0400 Subject: Added support for differentiating out and inout parameters. (#2323) * Added out/inout tests * Added support for out and inout parameters. Still untested * Fixed and tested support for out and inout types * Removed some comments --- source/slang/slang-check-expr.cpp | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) (limited to 'source/slang/slang-check-expr.cpp') diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 03da084d3..67e8bf650 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1509,18 +1509,40 @@ namespace Slang return expr; } + // This function proceses primal params (i.e params of the inner function that is being + // differentiated) that need to be carried over to the function signature for the JVP + // function. (eg. out types can be discarded) + // + Type* primalToInputType(ASTBuilder*, Type* primalType) + { + if (auto primalOutType = as(primalType)) + return nullptr; + else if (auto primalInOutType = as(primalType)) + return primalInOutType->getValueType(); + + return primalType; + } + Type* primalToJVPParamType(ASTBuilder* builder, Type* primalType) { // Only float and float3 types can be differentiated for now. - if(primalType->equals(builder->getFloatType())) + if (primalType->equals(builder->getFloatType())) return primalType; - else if(auto primalVectorType = as(primalType)) + else if (auto primalVectorType = as(primalType)) { // TODO(sai): There's probably a more elegant way to check if a type is a float3? if (getIntVal(primalVectorType->elementCount) == 3 && primalVectorType->elementType->equals(builder->getFloatType())) return primalVectorType; } + else if (auto primalOutType = as(primalType)) + { + return builder->getOutType(primalToJVPParamType(builder, primalOutType->getValueType())); + } + else if (auto primalInOutType = as(primalType)) + { + return builder->getInOutType(primalToJVPParamType(builder, primalInOutType->getValueType())); + } return nullptr; } @@ -1558,7 +1580,8 @@ namespace Slang for (UInt i = 0; i < primalType->getParamCount(); i++) { - jvpType->paramTypes.add(primalType->getParamType(i)); + if(auto primalInputType = primalToInputType(astBuilder, primalType->getParamType(i))) + jvpType->paramTypes.add(primalInputType); } for (UInt i = 0; i < primalType->getParamCount(); i++) -- cgit v1.2.3