summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-07-13 15:55:30 -0400
committerGitHub <noreply@github.com>2022-07-13 15:55:30 -0400
commit4af61e2296a49876c2d9e7cf192ae825302a83de (patch)
treed067b944b9794fe5061bbf51e8ef6a39d5fcefbf /source/slang/slang-check-expr.cpp
parent564f0d84a9c5276c05e8288955a7685f96278d1b (diff)
Added support for differentiating out and inout parameters. (#2323)
* Added out/inout tests * Added support for out and inout parameters. Still untested * Fixed and tested support for out and inout types * Removed some comments
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp29
1 files changed, 26 insertions, 3 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 03da084d3..67e8bf650 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1509,18 +1509,40 @@ namespace Slang
return expr;
}
+ // This function proceses primal params (i.e params of the inner function that is being
+ // differentiated) that need to be carried over to the function signature for the JVP
+ // function. (eg. out types can be discarded)
+ //
+ Type* primalToInputType(ASTBuilder*, Type* primalType)
+ {
+ if (auto primalOutType = as<OutType>(primalType))
+ return nullptr;
+ else if (auto primalInOutType = as<InOutType>(primalType))
+ return primalInOutType->getValueType();
+
+ return primalType;
+ }
+
Type* primalToJVPParamType(ASTBuilder* builder, Type* primalType)
{
// Only float and float3 types can be differentiated for now.
- if(primalType->equals(builder->getFloatType()))
+ if (primalType->equals(builder->getFloatType()))
return primalType;
- else if(auto primalVectorType = as<VectorExpressionType>(primalType))
+ else if (auto primalVectorType = as<VectorExpressionType>(primalType))
{
// TODO(sai): There's probably a more elegant way to check if a type is a float3?
if (getIntVal(primalVectorType->elementCount) == 3 && primalVectorType->elementType->equals(builder->getFloatType()))
return primalVectorType;
}
+ else if (auto primalOutType = as<OutType>(primalType))
+ {
+ return builder->getOutType(primalToJVPParamType(builder, primalOutType->getValueType()));
+ }
+ else if (auto primalInOutType = as<InOutType>(primalType))
+ {
+ return builder->getInOutType(primalToJVPParamType(builder, primalInOutType->getValueType()));
+ }
return nullptr;
}
@@ -1558,7 +1580,8 @@ namespace Slang
for (UInt i = 0; i < primalType->getParamCount(); i++)
{
- jvpType->paramTypes.add(primalType->getParamType(i));
+ if(auto primalInputType = primalToInputType(astBuilder, primalType->getParamType(i)))
+ jvpType->paramTypes.add(primalInputType);
}
for (UInt i = 0; i < primalType->getParamCount(); i++)