summaryrefslogtreecommitdiff
path: root/source/slang/slang-lower-to-ir.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-07-12 18:02:36 -0400
committerGitHub <noreply@github.com>2023-07-12 18:02:36 -0400
commitbbd9c2e6d7b57f5acc3238083ab2f7c7b140df5e (patch)
treeb6abec59b4ff3fe92436db35c1e61a6df236f550 /source/slang/slang-lower-to-ir.cpp
parent4ed3aafa20b667329f2f9dea94d7c65dc2e80db4 (diff)
Extend `no_diff` to support subscript operations on resources and array variables… (#2981)
* Extend `no_diff` to support subscript operations on resources and array variables * Update autodiff.slang.expected
Diffstat (limited to 'source/slang/slang-lower-to-ir.cpp')
-rw-r--r--source/slang/slang-lower-to-ir.cpp42
1 files changed, 39 insertions, 3 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index d869bf60e..a1c7a2b8e 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -3333,9 +3333,45 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
LoweredValInfo visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr)
{
auto baseVal = lowerSubExpr(expr->innerExpr);
- SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple);
- getBuilder()->addDecoration(baseVal.val, kIROp_TreatAsDifferentiableDecoration);
- return baseVal;
+
+ IRInst* innerInst = nullptr;
+ if (baseVal.flavor != LoweredValInfo::Flavor::Simple)
+ {
+ if (!isLValueContext())
+ {
+ auto materializedVal = materialize(context, baseVal);
+
+ // TODO(Sai): We might be missing the case where a single materialize could create
+ // multiple calls (multiple index operations?). Not quite sure what the right way
+ // to handle that case might be.
+ //
+ if (as<IRCall>(materializedVal.val))
+ getBuilder()->addDecoration(materializedVal.val, kIROp_TreatAsDifferentiableDecoration);
+
+ innerInst = getSimpleVal(context, materializedVal);
+
+ // We'll special case handle 'loads' here in order to allow TreatAsDifferentiable to be
+ // used on array index operations. (This is to avoid a discrepancy between using no_diff
+ // on local variable indexing vs. resource indexing.)
+ //
+ if (as<IRLoad>(innerInst))
+ innerInst = getBuilder()->emitDetachDerivative(innerInst->getDataType(), innerInst);
+ }
+ else
+ {
+ SLANG_ASSERT("TreatAsDifferentiableExpr on non-simple l-values not properly defined.");
+ }
+ }
+ else
+ {
+ if (as<IRCall>(baseVal.val))
+ getBuilder()->addDecoration(baseVal.val, kIROp_TreatAsDifferentiableDecoration);
+ innerInst = baseVal.val;
+ }
+
+ SLANG_ASSERT(innerInst);
+
+ return LoweredValInfo::simple(innerInst);
}
// Emit IR to denote the forward-mode derivative