diff options
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 29 |
1 files changed, 26 insertions, 3 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 03da084d3..67e8bf650 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1509,18 +1509,40 @@ namespace Slang return expr; } + // This function proceses primal params (i.e params of the inner function that is being + // differentiated) that need to be carried over to the function signature for the JVP + // function. (eg. out types can be discarded) + // + Type* primalToInputType(ASTBuilder*, Type* primalType) + { + if (auto primalOutType = as<OutType>(primalType)) + return nullptr; + else if (auto primalInOutType = as<InOutType>(primalType)) + return primalInOutType->getValueType(); + + return primalType; + } + Type* primalToJVPParamType(ASTBuilder* builder, Type* primalType) { // Only float and float3 types can be differentiated for now. - if(primalType->equals(builder->getFloatType())) + if (primalType->equals(builder->getFloatType())) return primalType; - else if(auto primalVectorType = as<VectorExpressionType>(primalType)) + else if (auto primalVectorType = as<VectorExpressionType>(primalType)) { // TODO(sai): There's probably a more elegant way to check if a type is a float3? if (getIntVal(primalVectorType->elementCount) == 3 && primalVectorType->elementType->equals(builder->getFloatType())) return primalVectorType; } + else if (auto primalOutType = as<OutType>(primalType)) + { + return builder->getOutType(primalToJVPParamType(builder, primalOutType->getValueType())); + } + else if (auto primalInOutType = as<InOutType>(primalType)) + { + return builder->getInOutType(primalToJVPParamType(builder, primalInOutType->getValueType())); + } return nullptr; } @@ -1558,7 +1580,8 @@ namespace Slang for (UInt i = 0; i < primalType->getParamCount(); i++) { - jvpType->paramTypes.add(primalType->getParamType(i)); + if(auto primalInputType = primalToInputType(astBuilder, primalType->getParamType(i))) + jvpType->paramTypes.add(primalInputType); } for (UInt i = 0; i < primalType->getParamCount(); i++) |
