summaryrefslogtreecommitdiff
path: root/source/slang/slang-lower-to-ir.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-10-24 22:19:38 -0700
committerGitHub <noreply@github.com>2022-10-24 22:19:38 -0700
commit41cb7c13e37ec32ffb6557d21da079d77151e136 (patch)
tree38d2c44938e2679c42c5c0e73f5411e59015df93 /source/slang/slang-lower-to-ir.cpp
parent1093218d6f0e114eb9fa52d60ca525bf9dd9f98a (diff)
Rework differentiation of member access through `[DerivativeMember(DiffType.field)]` (#2460)
* wip: remove auto-diff for member access, add diff through property accessors. * Fix getter-setter test. * Fix getter-setter-multi test. * Fix nested-jvp test. * Use [DerivativeMember] attribute to differentiate through member access. * Clean up. * More cleanup. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-lower-to-ir.cpp')
-rw-r--r--source/slang/slang-lower-to-ir.cpp79
1 files changed, 42 insertions, 37 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index dc6067868..1e58a456e 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -3038,38 +3038,6 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
return info;
}
- LoweredValInfo visitDifferentiableDeclRefExpr(DifferentiableDeclRefExpr* expr)
- {
- LoweredValInfo info = lowerSubExpr(expr->inner);
-
- IRInst* irBaseVal = nullptr;
- switch (info.flavor)
- {
- case LoweredValInfo::Flavor::Simple:
- irBaseVal = getSimpleVal(context, info);
- break;
-
- case LoweredValInfo::Flavor::Ptr:
- irBaseVal = info.val;
- break;
-
- default:
- SLANG_UNEXPECTED("Unhandled lowered value cases");
- }
-
- // If the differentiable expr has an associated getter or setter, lower it
- // and put it in a decoration.
- //
- if (expr->getterExpr != nullptr)
- {
- auto irGetter = lowerSubExpr(expr->getterExpr);
- SLANG_ASSERT(irGetter.flavor == LoweredValInfo::Flavor::Simple);
- getBuilder()->addDifferentialGetterDecoration(irBaseVal, irGetter.val);
- }
-
- return info;
- }
-
// Emit IR to denote the forward-mode derivative
// of the inner func-expr. This will be resolved
// to a concrete function during the derivative
@@ -6319,7 +6287,13 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// A variable declared inside of an aggregate type declaration is a member.
return true;
}
-
+ if (auto extDecl = as<ExtensionDecl>(parent))
+ {
+ if (auto declRefType = as<DeclRefType>(extDecl->targetType.type))
+ {
+ return true;
+ }
+ }
return false;
}
@@ -7108,6 +7082,14 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
builder->addDecoration(inst, op, operands.getBuffer(), operands.getCount());
}
+ void lowerDerivativeMemberModifier(IRInst* inst, DerivativeMemberAttribute* derivativeMember)
+ {
+ auto key = lowerRValueExpr(context, derivativeMember->memberDeclRef).val;
+ SLANG_RELEASE_ASSERT(as<IRStructKey>(key));
+ auto builder = getBuilder();
+ builder->addDecoration(inst, kIROp_JVPDerivativeMemberReferenceDecoration, key);
+ }
+
LoweredValInfo lowerMemberVarDecl(VarDecl* fieldDecl)
{
// Each field declaration in the AST translates into
@@ -7120,12 +7102,21 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// will use the same space of keys.
auto builder = getBuilder();
- auto irFieldKey = builder->createStructKey();
- addNameHint(context, irFieldKey, fieldDecl);
+ IRInst* irFieldKey = nullptr;
+ if (auto extVarModifier = fieldDecl->findModifier<ExtensionExternVarModifier>())
+ {
+ irFieldKey = ensureDecl(context, extVarModifier->originalDecl.getDecl()).val;
+ SLANG_RELEASE_ASSERT(as<IRStructKey>(irFieldKey));
+ }
- addVarDecorations(context, irFieldKey, fieldDecl);
+ if (!irFieldKey)
+ {
+ irFieldKey = builder->createStructKey();
- addLinkageDecoration(context, irFieldKey, fieldDecl);
+ addNameHint(context, irFieldKey, fieldDecl);
+ addVarDecorations(context, irFieldKey, fieldDecl);
+ addLinkageDecoration(context, irFieldKey, fieldDecl);
+ }
if (auto semanticModifier = fieldDecl->findModifier<HLSLSimpleSemantic>())
{
@@ -7140,6 +7131,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
lowerRayPayloadAccessModifier(irFieldKey, writeModifier, kIROp_StageWriteAccessDecoration);
}
+ if (auto derivativeMemberModifier = fieldDecl->findModifier<DerivativeMemberAttribute>())
+ {
+ lowerDerivativeMemberModifier(irFieldKey, derivativeMemberModifier);
+ }
// We allow a field to be marked as a target intrinsic,
// so that we can override its mangled name in the
@@ -7815,6 +7810,16 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
getBuilder()->addJVPDerivativeMarkerDecoration(irFunc);
}
+ // Always force inline diff setter accessor to prevent downstream compiler from complaining
+ // fields are not fully initialized for the first `inout` parameter.
+ if (as<SetterDecl>(decl))
+ {
+ if (!decl->findModifier<ForceInlineAttribute>())
+ {
+ getBuilder()->addForceInlineDecoration(irFunc);
+ }
+ }
+
FuncDeclBaseTypeInfo info;
_lowerFuncDeclBaseTypeInfo(
subContext,