From a911ca6e06ce41e403b80fe6054162393491c8ac Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 13 Mar 2023 10:57:28 -0700 Subject: Support high order diff pattern: `bwd_diff(fwd_diff(f))`. (#2695) * Support high order diff pattern: `bwd_diff(fwd_diff(f))`. * Fix. --------- Co-authored-by: Yong He --- source/slang/slang-lower-to-ir.cpp | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) (limited to 'source/slang/slang-lower-to-ir.cpp') diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index d8912cbd4..5e6213205 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -7285,10 +7285,18 @@ struct DeclLoweringVisitor : DeclVisitor builder->addDecoration(inst, op, operands.getBuffer(), operands.getCount()); } - void lowerDerivativeMemberModifier(IRInst* inst, DerivativeMemberAttribute* derivativeMember) + void lowerDerivativeMemberModifier(IRInst* inst, Decl* memberDecl, DerivativeMemberAttribute* derivativeMember) { - ensureDecl(context, derivativeMember->memberDeclRef->declRef.getDecl()->parentDecl); - auto key = lowerRValueExpr(context, derivativeMember->memberDeclRef).val; + IRInst* key = nullptr; + if (derivativeMember->memberDeclRef->declRef.getDecl() == memberDecl) + { + key = inst; + } + else + { + ensureDecl(context, derivativeMember->memberDeclRef->declRef.getDecl()->parentDecl); + key = lowerRValueExpr(context, derivativeMember->memberDeclRef).val; + } SLANG_RELEASE_ASSERT(as(key)); auto builder = getBuilder(); builder->addDecoration(inst, kIROp_DerivativeMemberDecoration, key); @@ -7358,7 +7366,7 @@ struct DeclLoweringVisitor : DeclVisitor } if (auto derivativeMemberModifier = fieldDecl->findModifier()) { - lowerDerivativeMemberModifier(irFieldKey, derivativeMemberModifier); + lowerDerivativeMemberModifier(irFieldKey, fieldDecl, derivativeMemberModifier); } // We allow a field to be marked as a target intrinsic, -- cgit v1.2.3