From a911ca6e06ce41e403b80fe6054162393491c8ac Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 13 Mar 2023 10:57:28 -0700 Subject: Support high order diff pattern: `bwd_diff(fwd_diff(f))`. (#2695) * Support high order diff pattern: `bwd_diff(fwd_diff(f))`. * Fix. --------- Co-authored-by: Yong He --- source/slang/slang-check-decl.cpp | 42 ++++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 12 deletions(-) (limited to 'source/slang/slang-check-decl.cpp') diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 5cd7fba45..ea8bec2bb 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1506,19 +1506,37 @@ namespace Slang aggTypeDecl->members.add(diffField); aggTypeDecl->invalidateMemberDictionary(); + // Inject a `DerivativeMember` modifier on the differential field to point to itself. + { + auto derivativeMemberModifier = m_astBuilder->create(); + auto fieldLookupExpr = m_astBuilder->create(); + fieldLookupExpr->type.type = diffMemberType; + auto baseTypeExpr = m_astBuilder->create(); + baseTypeExpr->base.type = differentialType; + auto baseTypeType = m_astBuilder->create(); + baseTypeType->type = differentialType; + baseTypeExpr->type.type = baseTypeType; + fieldLookupExpr->baseExpression = baseTypeExpr; + fieldLookupExpr->declRef = makeDeclRef(diffField); + derivativeMemberModifier->memberDeclRef = fieldLookupExpr; + addModifier(diffField, derivativeMemberModifier); + } + // Inject a `DerivativeMember` modifier on the original decl. - auto derivativeMemberModifier = m_astBuilder->create(); - auto fieldLookupExpr = m_astBuilder->create(); - fieldLookupExpr->type.type = diffMemberType; - auto baseTypeExpr = m_astBuilder->create(); - baseTypeExpr->base.type = differentialType; - auto baseTypeType = m_astBuilder->create(); - baseTypeType->type = differentialType; - baseTypeExpr->type.type = baseTypeType; - fieldLookupExpr->baseExpression = baseTypeExpr; - fieldLookupExpr->declRef = makeDeclRef(diffField); - derivativeMemberModifier->memberDeclRef = fieldLookupExpr; - addModifier(member, derivativeMemberModifier); + { + auto derivativeMemberModifier = m_astBuilder->create(); + auto fieldLookupExpr = m_astBuilder->create(); + fieldLookupExpr->type.type = diffMemberType; + auto baseTypeExpr = m_astBuilder->create(); + baseTypeExpr->base.type = differentialType; + auto baseTypeType = m_astBuilder->create(); + baseTypeType->type = differentialType; + baseTypeExpr->type.type = baseTypeType; + fieldLookupExpr->baseExpression = baseTypeExpr; + fieldLookupExpr->declRef = makeDeclRef(diffField); + derivativeMemberModifier->memberDeclRef = fieldLookupExpr; + addModifier(member, derivativeMemberModifier); + } }; // Make the Differential type itself conform to `IDifferential` interface. -- cgit v1.2.3