summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-01-10 13:33:02 -0800
committerGitHub <noreply@github.com>2025-01-10 13:33:02 -0800
commit4104aa7f95e0d29e877be5208031e2670fb5a77d (patch)
treee50d7642476668589a6aa5262fa773bd382461e8 /source
parentf199640bb31e1e273e34a068ea0fb7a55f2afb5e (diff)
Fix `markNonContextParamsAsSideEffectFree`. (#6054)
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp1
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp2
-rw-r--r--source/slang/slang-ir-inst-defs.h2
3 files changed, 4 insertions, 1 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 5ac4016d7..65ce69877 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -1359,6 +1359,7 @@ ParameterBlockTransposeInfo BackwardDiffTranscriberBase::splitAndTransposeParame
auto ctxParam =
builder->emitParam(as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 1));
builder->addNameHintDecoration(ctxParam, UnownedStringSlice("_s_diff_ctx"));
+ builder->addDecoration(ctxParam, kIROp_PrimalContextDecoration);
result.primalFuncParams.add(ctxParam);
result.propagateFuncParams.add(ctxParam);
result.dOutParam = dOutParam;
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 49c1d9ff7..6bc428ad6 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -342,7 +342,7 @@ void markNonContextParamsAsSideEffectFree(IRBuilder* builder, IRFunc* func)
{
for (auto param : func->getParams())
{
- if (!isIntermediateContextType(param->getDataType()))
+ if (!param->findDecorationImpl(kIROp_PrimalContextDecoration))
builder->addDecoration(param, kIROp_IgnoreSideEffectsDecoration);
}
}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index f5af73dfa..2f4c69820 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -1040,6 +1040,8 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace)
INST(BackwardDerivativePrimalContextDecoration, BackwardDerivativePrimalContextDecoration, 1, 0)
INST(BackwardDerivativePrimalReturnDecoration, BackwardDerivativePrimalReturnDecoration, 1, 0)
+ // Mark a parameter as autodiff primal context.
+ INST(PrimalContextDecoration, PrimalContextDecoration, 0, 0)
INST(LoopCounterDecoration, loopCounterDecoration, 0, 0)
INST(LoopCounterUpdateDecoration, loopCounterUpdateDecoration, 0, 0)