From aec57d849ae20a305d08348cf543d19eabc2e2d6 Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 10 Feb 2023 18:46:57 -0800 Subject: Fix several autodiff bugs. (#2643) --- source/slang/slang-ir-autodiff-fwd.cpp | 10 +++++++++- source/slang/slang-ir-autodiff-unzip.cpp | 4 ---- 2 files changed, 9 insertions(+), 5 deletions(-) (limited to 'source') diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index fca34f9a2..7782bd39c 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -148,7 +148,7 @@ InstPair ForwardDiffTranscriber::transcribeBinaryArith(IRBuilder* builder, IRIns builder->markInstAsDifferential(diffRightTimesLeft, resultType); builder->markInstAsDifferential(diffSub, resultType); - auto diffMul = builder->emitMul(resultType, primalRight, primalRight); + auto diffMul = builder->emitMul(primalRight->getFullType(), primalRight, primalRight); builder->markInstAsPrimal(diffMul); auto diffDiv = builder->emitDiv(diffType, diffSub, diffMul); @@ -877,6 +877,14 @@ InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRI diffBase, diffAccessChain, diffVal); builder->addPrimalElementTypeDecoration(diffUpdateElement, primalElementType); } + else + { + auto primalElementType = primalVal->getDataType(); + auto zeroElementDiff = getDifferentialZeroOfType(builder, primalElementType); + diffUpdateElement = builder->emitUpdateElement( + diffBase, diffAccessChain, zeroElementDiff); + builder->addPrimalElementTypeDecoration(diffUpdateElement, primalElementType); + } } } return InstPair(primalUpdateField, diffUpdateElement); diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 25f6c3964..6aaa40baf 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -532,10 +532,6 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( stripTempDecorations(func); - // Run simplification to DCE unnecessary insts. - eliminateDeadCode(func); - eliminateDeadCode(primalFunc); - return primalFunc; } } // namespace Slang -- cgit v1.2.3