From 41cb7c13e37ec32ffb6557d21da079d77151e136 Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 24 Oct 2022 22:19:38 -0700 Subject: Rework differentiation of member access through `[DerivativeMember(DiffType.field)]` (#2460) * wip: remove auto-diff for member access, add diff through property accessors. * Fix getter-setter test. * Fix getter-setter-multi test. * Fix nested-jvp test. * Use [DerivativeMember] attribute to differentiate through member access. * Clean up. * More cleanup. Co-authored-by: Yong He --- source/slang/slang-lower-to-ir.cpp | 79 ++++++++++++++++++++------------------ 1 file changed, 42 insertions(+), 37 deletions(-) (limited to 'source/slang/slang-lower-to-ir.cpp') diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index dc6067868..1e58a456e 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -3038,38 +3038,6 @@ struct ExprLoweringVisitorBase : ExprVisitor return info; } - LoweredValInfo visitDifferentiableDeclRefExpr(DifferentiableDeclRefExpr* expr) - { - LoweredValInfo info = lowerSubExpr(expr->inner); - - IRInst* irBaseVal = nullptr; - switch (info.flavor) - { - case LoweredValInfo::Flavor::Simple: - irBaseVal = getSimpleVal(context, info); - break; - - case LoweredValInfo::Flavor::Ptr: - irBaseVal = info.val; - break; - - default: - SLANG_UNEXPECTED("Unhandled lowered value cases"); - } - - // If the differentiable expr has an associated getter or setter, lower it - // and put it in a decoration. - // - if (expr->getterExpr != nullptr) - { - auto irGetter = lowerSubExpr(expr->getterExpr); - SLANG_ASSERT(irGetter.flavor == LoweredValInfo::Flavor::Simple); - getBuilder()->addDifferentialGetterDecoration(irBaseVal, irGetter.val); - } - - return info; - } - // Emit IR to denote the forward-mode derivative // of the inner func-expr. This will be resolved // to a concrete function during the derivative @@ -6319,7 +6287,13 @@ struct DeclLoweringVisitor : DeclVisitor // A variable declared inside of an aggregate type declaration is a member. return true; } - + if (auto extDecl = as(parent)) + { + if (auto declRefType = as(extDecl->targetType.type)) + { + return true; + } + } return false; } @@ -7108,6 +7082,14 @@ struct DeclLoweringVisitor : DeclVisitor builder->addDecoration(inst, op, operands.getBuffer(), operands.getCount()); } + void lowerDerivativeMemberModifier(IRInst* inst, DerivativeMemberAttribute* derivativeMember) + { + auto key = lowerRValueExpr(context, derivativeMember->memberDeclRef).val; + SLANG_RELEASE_ASSERT(as(key)); + auto builder = getBuilder(); + builder->addDecoration(inst, kIROp_JVPDerivativeMemberReferenceDecoration, key); + } + LoweredValInfo lowerMemberVarDecl(VarDecl* fieldDecl) { // Each field declaration in the AST translates into @@ -7120,12 +7102,21 @@ struct DeclLoweringVisitor : DeclVisitor // will use the same space of keys. auto builder = getBuilder(); - auto irFieldKey = builder->createStructKey(); - addNameHint(context, irFieldKey, fieldDecl); + IRInst* irFieldKey = nullptr; + if (auto extVarModifier = fieldDecl->findModifier()) + { + irFieldKey = ensureDecl(context, extVarModifier->originalDecl.getDecl()).val; + SLANG_RELEASE_ASSERT(as(irFieldKey)); + } - addVarDecorations(context, irFieldKey, fieldDecl); + if (!irFieldKey) + { + irFieldKey = builder->createStructKey(); - addLinkageDecoration(context, irFieldKey, fieldDecl); + addNameHint(context, irFieldKey, fieldDecl); + addVarDecorations(context, irFieldKey, fieldDecl); + addLinkageDecoration(context, irFieldKey, fieldDecl); + } if (auto semanticModifier = fieldDecl->findModifier()) { @@ -7140,6 +7131,10 @@ struct DeclLoweringVisitor : DeclVisitor { lowerRayPayloadAccessModifier(irFieldKey, writeModifier, kIROp_StageWriteAccessDecoration); } + if (auto derivativeMemberModifier = fieldDecl->findModifier()) + { + lowerDerivativeMemberModifier(irFieldKey, derivativeMemberModifier); + } // We allow a field to be marked as a target intrinsic, // so that we can override its mangled name in the @@ -7815,6 +7810,16 @@ struct DeclLoweringVisitor : DeclVisitor getBuilder()->addJVPDerivativeMarkerDecoration(irFunc); } + // Always force inline diff setter accessor to prevent downstream compiler from complaining + // fields are not fully initialized for the first `inout` parameter. + if (as(decl)) + { + if (!decl->findModifier()) + { + getBuilder()->addForceInlineDecoration(irFunc); + } + } + FuncDeclBaseTypeInfo info; _lowerFuncDeclBaseTypeInfo( subContext, -- cgit v1.2.3