From 4104aa7f95e0d29e877be5208031e2670fb5a77d Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 10 Jan 2025 13:33:02 -0800 Subject: Fix `markNonContextParamsAsSideEffectFree`. (#6054) --- source/slang/slang-ir-autodiff-rev.cpp | 1 + source/slang/slang-ir-autodiff-unzip.cpp | 2 +- source/slang/slang-ir-inst-defs.h | 2 ++ 3 files changed, 4 insertions(+), 1 deletion(-) (limited to 'source') diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 5ac4016d7..65ce69877 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -1359,6 +1359,7 @@ ParameterBlockTransposeInfo BackwardDiffTranscriberBase::splitAndTransposeParame auto ctxParam = builder->emitParam(as(diffFunc->getDataType())->getParamType(paramCount - 1)); builder->addNameHintDecoration(ctxParam, UnownedStringSlice("_s_diff_ctx")); + builder->addDecoration(ctxParam, kIROp_PrimalContextDecoration); result.primalFuncParams.add(ctxParam); result.propagateFuncParams.add(ctxParam); result.dOutParam = dOutParam; diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 49c1d9ff7..6bc428ad6 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -342,7 +342,7 @@ void markNonContextParamsAsSideEffectFree(IRBuilder* builder, IRFunc* func) { for (auto param : func->getParams()) { - if (!isIntermediateContextType(param->getDataType())) + if (!param->findDecorationImpl(kIROp_PrimalContextDecoration)) builder->addDecoration(param, kIROp_IgnoreSideEffectsDecoration); } } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index f5af73dfa..2f4c69820 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -1040,6 +1040,8 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace) INST(BackwardDerivativePrimalContextDecoration, BackwardDerivativePrimalContextDecoration, 1, 0) INST(BackwardDerivativePrimalReturnDecoration, BackwardDerivativePrimalReturnDecoration, 1, 0) + // Mark a parameter as autodiff primal context. + INST(PrimalContextDecoration, PrimalContextDecoration, 0, 0) INST(LoopCounterDecoration, loopCounterDecoration, 0, 0) INST(LoopCounterUpdateDecoration, loopCounterUpdateDecoration, 0, 0) -- cgit v1.2.3