diff options
| author | Yong He <yonghe@outlook.com> | 2022-10-24 22:19:38 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-24 22:19:38 -0700 |
| commit | 41cb7c13e37ec32ffb6557d21da079d77151e136 (patch) | |
| tree | 38d2c44938e2679c42c5c0e73f5411e59015df93 /source/slang/slang-lower-to-ir.cpp | |
| parent | 1093218d6f0e114eb9fa52d60ca525bf9dd9f98a (diff) | |
Rework differentiation of member access through `[DerivativeMember(DiffType.field)]` (#2460)
* wip: remove auto-diff for member access, add diff through property accessors.
* Fix getter-setter test.
* Fix getter-setter-multi test.
* Fix nested-jvp test.
* Use [DerivativeMember] attribute to differentiate through member access.
* Clean up.
* More cleanup.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-lower-to-ir.cpp')
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 79 |
1 files changed, 42 insertions, 37 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index dc6067868..1e58a456e 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -3038,38 +3038,6 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> return info; } - LoweredValInfo visitDifferentiableDeclRefExpr(DifferentiableDeclRefExpr* expr) - { - LoweredValInfo info = lowerSubExpr(expr->inner); - - IRInst* irBaseVal = nullptr; - switch (info.flavor) - { - case LoweredValInfo::Flavor::Simple: - irBaseVal = getSimpleVal(context, info); - break; - - case LoweredValInfo::Flavor::Ptr: - irBaseVal = info.val; - break; - - default: - SLANG_UNEXPECTED("Unhandled lowered value cases"); - } - - // If the differentiable expr has an associated getter or setter, lower it - // and put it in a decoration. - // - if (expr->getterExpr != nullptr) - { - auto irGetter = lowerSubExpr(expr->getterExpr); - SLANG_ASSERT(irGetter.flavor == LoweredValInfo::Flavor::Simple); - getBuilder()->addDifferentialGetterDecoration(irBaseVal, irGetter.val); - } - - return info; - } - // Emit IR to denote the forward-mode derivative // of the inner func-expr. This will be resolved // to a concrete function during the derivative @@ -6319,7 +6287,13 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // A variable declared inside of an aggregate type declaration is a member. return true; } - + if (auto extDecl = as<ExtensionDecl>(parent)) + { + if (auto declRefType = as<DeclRefType>(extDecl->targetType.type)) + { + return true; + } + } return false; } @@ -7108,6 +7082,14 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> builder->addDecoration(inst, op, operands.getBuffer(), operands.getCount()); } + void lowerDerivativeMemberModifier(IRInst* inst, DerivativeMemberAttribute* derivativeMember) + { + auto key = lowerRValueExpr(context, derivativeMember->memberDeclRef).val; + SLANG_RELEASE_ASSERT(as<IRStructKey>(key)); + auto builder = getBuilder(); + builder->addDecoration(inst, kIROp_JVPDerivativeMemberReferenceDecoration, key); + } + LoweredValInfo lowerMemberVarDecl(VarDecl* fieldDecl) { // Each field declaration in the AST translates into @@ -7120,12 +7102,21 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // will use the same space of keys. auto builder = getBuilder(); - auto irFieldKey = builder->createStructKey(); - addNameHint(context, irFieldKey, fieldDecl); + IRInst* irFieldKey = nullptr; + if (auto extVarModifier = fieldDecl->findModifier<ExtensionExternVarModifier>()) + { + irFieldKey = ensureDecl(context, extVarModifier->originalDecl.getDecl()).val; + SLANG_RELEASE_ASSERT(as<IRStructKey>(irFieldKey)); + } - addVarDecorations(context, irFieldKey, fieldDecl); + if (!irFieldKey) + { + irFieldKey = builder->createStructKey(); - addLinkageDecoration(context, irFieldKey, fieldDecl); + addNameHint(context, irFieldKey, fieldDecl); + addVarDecorations(context, irFieldKey, fieldDecl); + addLinkageDecoration(context, irFieldKey, fieldDecl); + } if (auto semanticModifier = fieldDecl->findModifier<HLSLSimpleSemantic>()) { @@ -7140,6 +7131,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { lowerRayPayloadAccessModifier(irFieldKey, writeModifier, kIROp_StageWriteAccessDecoration); } + if (auto derivativeMemberModifier = fieldDecl->findModifier<DerivativeMemberAttribute>()) + { + lowerDerivativeMemberModifier(irFieldKey, derivativeMemberModifier); + } // We allow a field to be marked as a target intrinsic, // so that we can override its mangled name in the @@ -7815,6 +7810,16 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> getBuilder()->addJVPDerivativeMarkerDecoration(irFunc); } + // Always force inline diff setter accessor to prevent downstream compiler from complaining + // fields are not fully initialized for the first `inout` parameter. + if (as<SetterDecl>(decl)) + { + if (!decl->findModifier<ForceInlineAttribute>()) + { + getBuilder()->addForceInlineDecoration(irFunc); + } + } + FuncDeclBaseTypeInfo info; _lowerFuncDeclBaseTypeInfo( subContext, |
