summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp1
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp2
-rw-r--r--source/slang/slang-ir-inst-defs.h2
3 files changed, 4 insertions, 1 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 5ac4016d7..65ce69877 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -1359,6 +1359,7 @@ ParameterBlockTransposeInfo BackwardDiffTranscriberBase::splitAndTransposeParame
auto ctxParam =
builder->emitParam(as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 1));
builder->addNameHintDecoration(ctxParam, UnownedStringSlice("_s_diff_ctx"));
+ builder->addDecoration(ctxParam, kIROp_PrimalContextDecoration);
result.primalFuncParams.add(ctxParam);
result.propagateFuncParams.add(ctxParam);
result.dOutParam = dOutParam;
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 49c1d9ff7..6bc428ad6 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -342,7 +342,7 @@ void markNonContextParamsAsSideEffectFree(IRBuilder* builder, IRFunc* func)
{
for (auto param : func->getParams())
{
- if (!isIntermediateContextType(param->getDataType()))
+ if (!param->findDecorationImpl(kIROp_PrimalContextDecoration))
builder->addDecoration(param, kIROp_IgnoreSideEffectsDecoration);
}
}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index f5af73dfa..2f4c69820 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -1040,6 +1040,8 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace)
INST(BackwardDerivativePrimalContextDecoration, BackwardDerivativePrimalContextDecoration, 1, 0)
INST(BackwardDerivativePrimalReturnDecoration, BackwardDerivativePrimalReturnDecoration, 1, 0)
+ // Mark a parameter as autodiff primal context.
+ INST(PrimalContextDecoration, PrimalContextDecoration, 0, 0)
INST(LoopCounterDecoration, loopCounterDecoration, 0, 0)
INST(LoopCounterUpdateDecoration, loopCounterUpdateDecoration, 0, 0)