summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-07-18 23:32:30 -0400
committerGitHub <noreply@github.com>2022-07-18 23:32:30 -0400
commit5b4f35b8d00661852c607a49d81c590d4050a166 (patch)
tree320027c51f44c83c731d9121e41453dda67ed3ce /source/slang/slang-check-expr.cpp
parent2e4b5770fa7e6dbf56845382706b33a22d6a025b (diff)
Added forward-mode autodiff support for more instructions (#2331)
* Merge slang-ir-diff-jvp.cpp * Added support and tests for other float vector types * Added swizzle test and code to handle it (tests failing currently) * Fixed one test, the other is still pending * Fixed instruction cloning logic to avoid modifying original function * Fixed an issue with custom 'pow_jvp' and added support for vector contructor * Minor update to comments * Fixed support for division * Fixed an issue with uninitialized diagnostic sink * Moved derivative processing to after mandatory inlining. Skip instructions that don't have side-effects and aren't used by anything. * WIP: Handling unconditional control flow and multi-block functions * Support for unconditional multi-block functions * Added a dead code elimination step to the derivative pass * Changed name of 'hasNoSideEffects()'
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp7
1 files changed, 3 insertions, 4 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 67e8bf650..1895da70b 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1525,15 +1525,14 @@ namespace Slang
Type* primalToJVPParamType(ASTBuilder* builder, Type* primalType)
{
- // Only float and float3 types can be differentiated for now.
+ // Only float and vector<float> types can be differentiated for now.
if (primalType->equals(builder->getFloatType()))
return primalType;
else if (auto primalVectorType = as<VectorExpressionType>(primalType))
{
- // TODO(sai): There's probably a more elegant way to check if a type is a float3?
- if (getIntVal(primalVectorType->elementCount) == 3 && primalVectorType->elementType->equals(builder->getFloatType()))
- return primalVectorType;
+ if (auto jvpElementType = primalToJVPParamType(builder, primalVectorType->elementType))
+ return builder->getVectorType(jvpElementType, primalVectorType->elementCount);
}
else if (auto primalOutType = as<OutType>(primalType))
{