From 5b4f35b8d00661852c607a49d81c590d4050a166 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Mon, 18 Jul 2022 23:32:30 -0400 Subject: Added forward-mode autodiff support for more instructions (#2331) * Merge slang-ir-diff-jvp.cpp * Added support and tests for other float vector types * Added swizzle test and code to handle it (tests failing currently) * Fixed one test, the other is still pending * Fixed instruction cloning logic to avoid modifying original function * Fixed an issue with custom 'pow_jvp' and added support for vector contructor * Minor update to comments * Fixed support for division * Fixed an issue with uninitialized diagnostic sink * Moved derivative processing to after mandatory inlining. Skip instructions that don't have side-effects and aren't used by anything. * WIP: Handling unconditional control flow and multi-block functions * Support for unconditional multi-block functions * Added a dead code elimination step to the derivative pass * Changed name of 'hasNoSideEffects()' --- source/slang/slang-check-expr.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) (limited to 'source/slang/slang-check-expr.cpp') diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 67e8bf650..1895da70b 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1525,15 +1525,14 @@ namespace Slang Type* primalToJVPParamType(ASTBuilder* builder, Type* primalType) { - // Only float and float3 types can be differentiated for now. + // Only float and vector types can be differentiated for now. if (primalType->equals(builder->getFloatType())) return primalType; else if (auto primalVectorType = as(primalType)) { - // TODO(sai): There's probably a more elegant way to check if a type is a float3? - if (getIntVal(primalVectorType->elementCount) == 3 && primalVectorType->elementType->equals(builder->getFloatType())) - return primalVectorType; + if (auto jvpElementType = primalToJVPParamType(builder, primalVectorType->elementType)) + return builder->getVectorType(jvpElementType, primalVectorType->elementCount); } else if (auto primalOutType = as(primalType)) { -- cgit v1.2.3