From d48cd130aacbab34bb98d51bb237ad38ff37348c Mon Sep 17 00:00:00 2001 From: kaizhangNV <149626564+kaizhangNV@users.noreply.github.com> Date: Thu, 2 Jan 2025 14:29:57 -0600 Subject: Correct IR generation for no-diff pointer type (#5976) * Correct IR generation for no-diff pointer type Close #5805 There is an issue on checking whether a pointer type parameter is no_diff, we should first check whether this parameter is an Attribute type first, then check the data type. In the back-propagate pass, for the pointer type parameter, we should load this parameter to a temp variable, then pass it to the primal function call. Otherwise, the temp variable will no be initialized, which will cause the following calculation wrong. --- source/slang/slang-ir-autodiff.cpp | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) (limited to 'source/slang/slang-ir-autodiff.cpp') diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 5c05b0811..4edd8eabe 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -126,13 +126,22 @@ static IRInst* _getDiffTypeWitnessFromPairType( bool isNoDiffType(IRType* paramType) { - while (auto ptrType = as(paramType)) - paramType = ptrType->getValueType(); - while (auto attrType = as(paramType)) + while (paramType) { - if (attrType->findAttr()) + if (auto attrType = as(paramType)) { - return true; + if (attrType->findAttr()) + return true; + + paramType = attrType->getBaseType(); + } + else if (auto ptrType = as(paramType)) + { + paramType = ptrType->getValueType(); + } + else + { + return false; } } return false; -- cgit v1.2.3