From d48cd130aacbab34bb98d51bb237ad38ff37348c Mon Sep 17 00:00:00 2001 From: kaizhangNV <149626564+kaizhangNV@users.noreply.github.com> Date: Thu, 2 Jan 2025 14:29:57 -0600 Subject: Correct IR generation for no-diff pointer type (#5976) * Correct IR generation for no-diff pointer type Close #5805 There is an issue on checking whether a pointer type parameter is no_diff, we should first check whether this parameter is an Attribute type first, then check the data type. In the back-propagate pass, for the pointer type parameter, we should load this parameter to a temp variable, then pass it to the primal function call. Otherwise, the temp variable will no be initialized, which will cause the following calculation wrong. --- source/slang/slang-ir-autodiff-rev.cpp | 11 ++++++----- source/slang/slang-ir-autodiff-transcriber-base.cpp | 3 +++ source/slang/slang-ir-autodiff.cpp | 19 ++++++++++++++----- 3 files changed, 23 insertions(+), 10 deletions(-) (limited to 'source/slang') diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index f0ac428c7..36093518a 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -512,11 +512,12 @@ InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRF { // If primal parameter is mutable, we need to pass in a temp var. auto tempVar = builder.emitVar(primalParamPtrType->getValueType()); - if (primalParamPtrType->getOp() == kIROp_InOutType) - { - // If the primal parameter is inout, we need to set the initial value. - builder.emitStore(tempVar, primalArg); - } + + // We also need to setup the initial value of the temp var, otherwise + // the temp var will be uninitialized which could cause undefined behavior + // in the primal function. + builder.emitStore(tempVar, primalArg); + primalArgs.add(tempVar); } else diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index ada35689c..1b3825a7d 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -565,6 +565,9 @@ IRType* AutoDiffTranscriberBase::tryGetDiffPairType(IRBuilder* builder, IRType* // If this is a PtrType (out, inout, etc..), then create diff pair from // value type and re-apply the appropropriate PtrType wrapper. // + if (isNoDiffType(originalType)) + return nullptr; + if (auto origPtrType = as(originalType)) { if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType())) diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 5c05b0811..4edd8eabe 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -126,13 +126,22 @@ static IRInst* _getDiffTypeWitnessFromPairType( bool isNoDiffType(IRType* paramType) { - while (auto ptrType = as(paramType)) - paramType = ptrType->getValueType(); - while (auto attrType = as(paramType)) + while (paramType) { - if (attrType->findAttr()) + if (auto attrType = as(paramType)) { - return true; + if (attrType->findAttr()) + return true; + + paramType = attrType->getBaseType(); + } + else if (auto ptrType = as(paramType)) + { + paramType = ptrType->getValueType(); + } + else + { + return false; } } return false; -- cgit v1.2.3