diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-07-13 15:55:30 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-07-13 15:55:30 -0400 |
| commit | 4af61e2296a49876c2d9e7cf192ae825302a83de (patch) | |
| tree | d067b944b9794fe5061bbf51e8ef6a39d5fcefbf /source/slang/slang-check-expr.cpp | |
| parent | 564f0d84a9c5276c05e8288955a7685f96278d1b (diff) | |
Added support for differentiating out and inout parameters. (#2323)
* Added out/inout tests
* Added support for out and inout parameters. Still untested
* Fixed and tested support for out and inout types
* Removed some comments
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 29 |
1 files changed, 26 insertions, 3 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 03da084d3..67e8bf650 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1509,18 +1509,40 @@ namespace Slang return expr; } + // This function proceses primal params (i.e params of the inner function that is being + // differentiated) that need to be carried over to the function signature for the JVP + // function. (eg. out types can be discarded) + // + Type* primalToInputType(ASTBuilder*, Type* primalType) + { + if (auto primalOutType = as<OutType>(primalType)) + return nullptr; + else if (auto primalInOutType = as<InOutType>(primalType)) + return primalInOutType->getValueType(); + + return primalType; + } + Type* primalToJVPParamType(ASTBuilder* builder, Type* primalType) { // Only float and float3 types can be differentiated for now. - if(primalType->equals(builder->getFloatType())) + if (primalType->equals(builder->getFloatType())) return primalType; - else if(auto primalVectorType = as<VectorExpressionType>(primalType)) + else if (auto primalVectorType = as<VectorExpressionType>(primalType)) { // TODO(sai): There's probably a more elegant way to check if a type is a float3? if (getIntVal(primalVectorType->elementCount) == 3 && primalVectorType->elementType->equals(builder->getFloatType())) return primalVectorType; } + else if (auto primalOutType = as<OutType>(primalType)) + { + return builder->getOutType(primalToJVPParamType(builder, primalOutType->getValueType())); + } + else if (auto primalInOutType = as<InOutType>(primalType)) + { + return builder->getInOutType(primalToJVPParamType(builder, primalInOutType->getValueType())); + } return nullptr; } @@ -1558,7 +1580,8 @@ namespace Slang for (UInt i = 0; i < primalType->getParamCount(); i++) { - jvpType->paramTypes.add(primalType->getParamType(i)); + if(auto primalInputType = primalToInputType(astBuilder, primalType->getParamType(i))) + jvpType->paramTypes.add(primalInputType); } for (UInt i = 0; i < primalType->getParamCount(); i++) |
