diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-07-18 23:32:30 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-07-18 23:32:30 -0400 |
| commit | 5b4f35b8d00661852c607a49d81c590d4050a166 (patch) | |
| tree | 320027c51f44c83c731d9121e41453dda67ed3ce /source/slang/slang-check-expr.cpp | |
| parent | 2e4b5770fa7e6dbf56845382706b33a22d6a025b (diff) | |
Added forward-mode autodiff support for more instructions (#2331)
* Merge slang-ir-diff-jvp.cpp
* Added support and tests for other float vector types
* Added swizzle test and code to handle it (tests failing currently)
* Fixed one test, the other is still pending
* Fixed instruction cloning logic to avoid modifying original function
* Fixed an issue with custom 'pow_jvp' and added support for vector contructor
* Minor update to comments
* Fixed support for division
* Fixed an issue with uninitialized diagnostic sink
* Moved derivative processing to after mandatory inlining.
Skip instructions that don't have side-effects and aren't used by anything.
* WIP: Handling unconditional control flow and multi-block functions
* Support for unconditional multi-block functions
* Added a dead code elimination step to the derivative pass
* Changed name of 'hasNoSideEffects()'
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 7 |
1 files changed, 3 insertions, 4 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 67e8bf650..1895da70b 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1525,15 +1525,14 @@ namespace Slang Type* primalToJVPParamType(ASTBuilder* builder, Type* primalType) { - // Only float and float3 types can be differentiated for now. + // Only float and vector<float> types can be differentiated for now. if (primalType->equals(builder->getFloatType())) return primalType; else if (auto primalVectorType = as<VectorExpressionType>(primalType)) { - // TODO(sai): There's probably a more elegant way to check if a type is a float3? - if (getIntVal(primalVectorType->elementCount) == 3 && primalVectorType->elementType->equals(builder->getFloatType())) - return primalVectorType; + if (auto jvpElementType = primalToJVPParamType(builder, primalVectorType->elementType)) + return builder->getVectorType(jvpElementType, primalVectorType->elementCount); } else if (auto primalOutType = as<OutType>(primalType)) { |
