summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-02-10 18:46:57 -0800
committerGitHub <noreply@github.com>2023-02-10 18:46:57 -0800
commitaec57d849ae20a305d08348cf543d19eabc2e2d6 (patch)
treeafac620a888d27ee1000b036c4ab8c3773180af3 /source
parent6e7b424953ae6732d4863e887e7e452396095d71 (diff)
Fix several autodiff bugs. (#2643)
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp10
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp4
2 files changed, 9 insertions, 5 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index fca34f9a2..7782bd39c 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -148,7 +148,7 @@ InstPair ForwardDiffTranscriber::transcribeBinaryArith(IRBuilder* builder, IRIns
builder->markInstAsDifferential(diffRightTimesLeft, resultType);
builder->markInstAsDifferential(diffSub, resultType);
- auto diffMul = builder->emitMul(resultType, primalRight, primalRight);
+ auto diffMul = builder->emitMul(primalRight->getFullType(), primalRight, primalRight);
builder->markInstAsPrimal(diffMul);
auto diffDiv = builder->emitDiv(diffType, diffSub, diffMul);
@@ -877,6 +877,14 @@ InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRI
diffBase, diffAccessChain, diffVal);
builder->addPrimalElementTypeDecoration(diffUpdateElement, primalElementType);
}
+ else
+ {
+ auto primalElementType = primalVal->getDataType();
+ auto zeroElementDiff = getDifferentialZeroOfType(builder, primalElementType);
+ diffUpdateElement = builder->emitUpdateElement(
+ diffBase, diffAccessChain, zeroElementDiff);
+ builder->addPrimalElementTypeDecoration(diffUpdateElement, primalElementType);
+ }
}
}
return InstPair(primalUpdateField, diffUpdateElement);
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 25f6c3964..6aaa40baf 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -532,10 +532,6 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
stripTempDecorations(func);
- // Run simplification to DCE unnecessary insts.
- eliminateDeadCode(func);
- eliminateDeadCode(primalFunc);
-
return primalFunc;
}
} // namespace Slang