diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-13 10:57:28 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-13 10:57:28 -0700 |
| commit | a911ca6e06ce41e403b80fe6054162393491c8ac (patch) | |
| tree | 6c8d56a3060b1887e7fd3126fe54a1241160eddd /source/slang/slang-lower-to-ir.cpp | |
| parent | 3fea56ef77a33273bf5af6f432163b30c0a0e1dc (diff) | |
Support high order diff pattern: `bwd_diff(fwd_diff(f))`. (#2695)
* Support high order diff pattern: `bwd_diff(fwd_diff(f))`.
* Fix.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-lower-to-ir.cpp')
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index d8912cbd4..5e6213205 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -7285,10 +7285,18 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> builder->addDecoration(inst, op, operands.getBuffer(), operands.getCount()); } - void lowerDerivativeMemberModifier(IRInst* inst, DerivativeMemberAttribute* derivativeMember) + void lowerDerivativeMemberModifier(IRInst* inst, Decl* memberDecl, DerivativeMemberAttribute* derivativeMember) { - ensureDecl(context, derivativeMember->memberDeclRef->declRef.getDecl()->parentDecl); - auto key = lowerRValueExpr(context, derivativeMember->memberDeclRef).val; + IRInst* key = nullptr; + if (derivativeMember->memberDeclRef->declRef.getDecl() == memberDecl) + { + key = inst; + } + else + { + ensureDecl(context, derivativeMember->memberDeclRef->declRef.getDecl()->parentDecl); + key = lowerRValueExpr(context, derivativeMember->memberDeclRef).val; + } SLANG_RELEASE_ASSERT(as<IRStructKey>(key)); auto builder = getBuilder(); builder->addDecoration(inst, kIROp_DerivativeMemberDecoration, key); @@ -7358,7 +7366,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } if (auto derivativeMemberModifier = fieldDecl->findModifier<DerivativeMemberAttribute>()) { - lowerDerivativeMemberModifier(irFieldKey, derivativeMemberModifier); + lowerDerivativeMemberModifier(irFieldKey, fieldDecl, derivativeMemberModifier); } // We allow a field to be marked as a target intrinsic, |
