summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-13 10:57:28 -0700
committerGitHub <noreply@github.com>2023-03-13 10:57:28 -0700
commita911ca6e06ce41e403b80fe6054162393491c8ac (patch)
tree6c8d56a3060b1887e7fd3126fe54a1241160eddd /source/slang/slang-check-decl.cpp
parent3fea56ef77a33273bf5af6f432163b30c0a0e1dc (diff)
Support high order diff pattern: `bwd_diff(fwd_diff(f))`. (#2695)
* Support high order diff pattern: `bwd_diff(fwd_diff(f))`. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp42
1 files changed, 30 insertions, 12 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 5cd7fba45..ea8bec2bb 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -1506,19 +1506,37 @@ namespace Slang
aggTypeDecl->members.add(diffField);
aggTypeDecl->invalidateMemberDictionary();
+ // Inject a `DerivativeMember` modifier on the differential field to point to itself.
+ {
+ auto derivativeMemberModifier = m_astBuilder->create<DerivativeMemberAttribute>();
+ auto fieldLookupExpr = m_astBuilder->create<StaticMemberExpr>();
+ fieldLookupExpr->type.type = diffMemberType;
+ auto baseTypeExpr = m_astBuilder->create<SharedTypeExpr>();
+ baseTypeExpr->base.type = differentialType;
+ auto baseTypeType = m_astBuilder->create<TypeType>();
+ baseTypeType->type = differentialType;
+ baseTypeExpr->type.type = baseTypeType;
+ fieldLookupExpr->baseExpression = baseTypeExpr;
+ fieldLookupExpr->declRef = makeDeclRef(diffField);
+ derivativeMemberModifier->memberDeclRef = fieldLookupExpr;
+ addModifier(diffField, derivativeMemberModifier);
+ }
+
// Inject a `DerivativeMember` modifier on the original decl.
- auto derivativeMemberModifier = m_astBuilder->create<DerivativeMemberAttribute>();
- auto fieldLookupExpr = m_astBuilder->create<StaticMemberExpr>();
- fieldLookupExpr->type.type = diffMemberType;
- auto baseTypeExpr = m_astBuilder->create<SharedTypeExpr>();
- baseTypeExpr->base.type = differentialType;
- auto baseTypeType = m_astBuilder->create<TypeType>();
- baseTypeType->type = differentialType;
- baseTypeExpr->type.type = baseTypeType;
- fieldLookupExpr->baseExpression = baseTypeExpr;
- fieldLookupExpr->declRef = makeDeclRef(diffField);
- derivativeMemberModifier->memberDeclRef = fieldLookupExpr;
- addModifier(member, derivativeMemberModifier);
+ {
+ auto derivativeMemberModifier = m_astBuilder->create<DerivativeMemberAttribute>();
+ auto fieldLookupExpr = m_astBuilder->create<StaticMemberExpr>();
+ fieldLookupExpr->type.type = diffMemberType;
+ auto baseTypeExpr = m_astBuilder->create<SharedTypeExpr>();
+ baseTypeExpr->base.type = differentialType;
+ auto baseTypeType = m_astBuilder->create<TypeType>();
+ baseTypeType->type = differentialType;
+ baseTypeExpr->type.type = baseTypeType;
+ fieldLookupExpr->baseExpression = baseTypeExpr;
+ fieldLookupExpr->declRef = makeDeclRef(diffField);
+ derivativeMemberModifier->memberDeclRef = fieldLookupExpr;
+ addModifier(member, derivativeMemberModifier);
+ }
};
// Make the Differential type itself conform to `IDifferential` interface.