diff options
| author | Yong He <yonghe@outlook.com> | 2023-02-10 18:46:57 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-10 18:46:57 -0800 |
| commit | aec57d849ae20a305d08348cf543d19eabc2e2d6 (patch) | |
| tree | afac620a888d27ee1000b036c4ab8c3773180af3 /source | |
| parent | 6e7b424953ae6732d4863e887e7e452396095d71 (diff) | |
Fix several autodiff bugs. (#2643)
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.cpp | 4 |
2 files changed, 9 insertions, 5 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index fca34f9a2..7782bd39c 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -148,7 +148,7 @@ InstPair ForwardDiffTranscriber::transcribeBinaryArith(IRBuilder* builder, IRIns builder->markInstAsDifferential(diffRightTimesLeft, resultType); builder->markInstAsDifferential(diffSub, resultType); - auto diffMul = builder->emitMul(resultType, primalRight, primalRight); + auto diffMul = builder->emitMul(primalRight->getFullType(), primalRight, primalRight); builder->markInstAsPrimal(diffMul); auto diffDiv = builder->emitDiv(diffType, diffSub, diffMul); @@ -877,6 +877,14 @@ InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRI diffBase, diffAccessChain, diffVal); builder->addPrimalElementTypeDecoration(diffUpdateElement, primalElementType); } + else + { + auto primalElementType = primalVal->getDataType(); + auto zeroElementDiff = getDifferentialZeroOfType(builder, primalElementType); + diffUpdateElement = builder->emitUpdateElement( + diffBase, diffAccessChain, zeroElementDiff); + builder->addPrimalElementTypeDecoration(diffUpdateElement, primalElementType); + } } } return InstPair(primalUpdateField, diffUpdateElement); diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 25f6c3964..6aaa40baf 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -532,10 +532,6 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( stripTempDecorations(func); - // Run simplification to DCE unnecessary insts. - eliminateDeadCode(func); - eliminateDeadCode(primalFunc); - return primalFunc; } } // namespace Slang |
