summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp11
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp3
-rw-r--r--source/slang/slang-ir-autodiff.cpp19
3 files changed, 23 insertions, 10 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index f0ac428c7..36093518a 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -512,11 +512,12 @@ InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRF
{
// If primal parameter is mutable, we need to pass in a temp var.
auto tempVar = builder.emitVar(primalParamPtrType->getValueType());
- if (primalParamPtrType->getOp() == kIROp_InOutType)
- {
- // If the primal parameter is inout, we need to set the initial value.
- builder.emitStore(tempVar, primalArg);
- }
+
+ // We also need to setup the initial value of the temp var, otherwise
+ // the temp var will be uninitialized which could cause undefined behavior
+ // in the primal function.
+ builder.emitStore(tempVar, primalArg);
+
primalArgs.add(tempVar);
}
else
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index ada35689c..1b3825a7d 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -565,6 +565,9 @@ IRType* AutoDiffTranscriberBase::tryGetDiffPairType(IRBuilder* builder, IRType*
// If this is a PtrType (out, inout, etc..), then create diff pair from
// value type and re-apply the appropropriate PtrType wrapper.
//
+ if (isNoDiffType(originalType))
+ return nullptr;
+
if (auto origPtrType = as<IRPtrTypeBase>(originalType))
{
if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType()))
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 5c05b0811..4edd8eabe 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -126,13 +126,22 @@ static IRInst* _getDiffTypeWitnessFromPairType(
bool isNoDiffType(IRType* paramType)
{
- while (auto ptrType = as<IRPtrTypeBase>(paramType))
- paramType = ptrType->getValueType();
- while (auto attrType = as<IRAttributedType>(paramType))
+ while (paramType)
{
- if (attrType->findAttr<IRNoDiffAttr>())
+ if (auto attrType = as<IRAttributedType>(paramType))
{
- return true;
+ if (attrType->findAttr<IRNoDiffAttr>())
+ return true;
+
+ paramType = attrType->getBaseType();
+ }
+ else if (auto ptrType = as<IRPtrTypeBase>(paramType))
+ {
+ paramType = ptrType->getValueType();
+ }
+ else
+ {
+ return false;
}
}
return false;