diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-07-12 18:02:36 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-07-12 18:02:36 -0400 |
| commit | bbd9c2e6d7b57f5acc3238083ab2f7c7b140df5e (patch) | |
| tree | b6abec59b4ff3fe92436db35c1e61a6df236f550 /source/slang/slang-lower-to-ir.cpp | |
| parent | 4ed3aafa20b667329f2f9dea94d7c65dc2e80db4 (diff) | |
Extend `no_diff` to support subscript operations on resources and array variables… (#2981)
* Extend `no_diff` to support subscript operations on resources and array variables
* Update autodiff.slang.expected
Diffstat (limited to 'source/slang/slang-lower-to-ir.cpp')
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 42 |
1 files changed, 39 insertions, 3 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index d869bf60e..a1c7a2b8e 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -3333,9 +3333,45 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> LoweredValInfo visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr) { auto baseVal = lowerSubExpr(expr->innerExpr); - SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple); - getBuilder()->addDecoration(baseVal.val, kIROp_TreatAsDifferentiableDecoration); - return baseVal; + + IRInst* innerInst = nullptr; + if (baseVal.flavor != LoweredValInfo::Flavor::Simple) + { + if (!isLValueContext()) + { + auto materializedVal = materialize(context, baseVal); + + // TODO(Sai): We might be missing the case where a single materialize could create + // multiple calls (multiple index operations?). Not quite sure what the right way + // to handle that case might be. + // + if (as<IRCall>(materializedVal.val)) + getBuilder()->addDecoration(materializedVal.val, kIROp_TreatAsDifferentiableDecoration); + + innerInst = getSimpleVal(context, materializedVal); + + // We'll special case handle 'loads' here in order to allow TreatAsDifferentiable to be + // used on array index operations. (This is to avoid a discrepancy between using no_diff + // on local variable indexing vs. resource indexing.) + // + if (as<IRLoad>(innerInst)) + innerInst = getBuilder()->emitDetachDerivative(innerInst->getDataType(), innerInst); + } + else + { + SLANG_ASSERT("TreatAsDifferentiableExpr on non-simple l-values not properly defined."); + } + } + else + { + if (as<IRCall>(baseVal.val)) + getBuilder()->addDecoration(baseVal.val, kIROp_TreatAsDifferentiableDecoration); + innerInst = baseVal.val; + } + + SLANG_ASSERT(innerInst); + + return LoweredValInfo::simple(innerInst); } // Emit IR to denote the forward-mode derivative |
