summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-15 23:26:14 -0700
committerGitHub <noreply@github.com>2023-03-15 23:26:14 -0700
commitf7431f96e1cad2a68534bebc1f25cd6f65f87f82 (patch)
treea9171c4705f53f246f8a75d5ced752d6538a448c
parent3a1a6c47cf7f24b09016e08b7cd7e2863911b08d (diff)
Fix `transcribeConstruct` for `makeStruct`. (#2703)
Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp61
-rw-r--r--source/slang/slang-ir-autodiff-fwd.h1
2 files changed, 61 insertions, 1 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 247c3ddde..e9c156055 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -346,6 +346,64 @@ InstPair ForwardDiffTranscriber::transcribeConstruct(IRBuilder* builder, IRInst*
}
}
+InstPair ForwardDiffTranscriber::transcribeMakeStruct(IRBuilder* builder, IRInst* origMakeStruct)
+{
+ IRInst* primalMakeStruct = maybeCloneForPrimalInst(builder, origMakeStruct);
+
+ // Check if the output type can be differentiated. If it cannot be
+ // differentiated, don't differentiate the inst
+ //
+ auto primalStructType = (IRType*)findOrTranscribePrimalInst(builder, origMakeStruct->getDataType());
+ if (auto diffStructType = differentiateType(builder, primalStructType))
+ {
+ auto primalStruct = as<IRStructType>(getResolvedInstForDecorations(primalStructType));
+ SLANG_RELEASE_ASSERT(primalStruct);
+
+ List<IRInst*> diffOperands;
+ UIndex ii = 0;
+ for (auto field : primalStruct->getFields())
+ {
+ SLANG_RELEASE_ASSERT(ii < origMakeStruct->getOperandCount());
+
+ // If this field is not differentiable, skip the operand.
+ if (!field->getKey()->findDecoration<IRDerivativeMemberDecoration>())
+ {
+ ii++;
+ continue;
+ }
+
+ // If the operand has a differential version, replace the original with
+ // the differential. Otherwise, use a zero.
+ //
+ if (auto diffInst = lookupDiffInst(origMakeStruct->getOperand(ii), nullptr))
+ {
+ diffOperands.add(diffInst);
+ }
+ else
+ {
+ auto operandDataType = origMakeStruct->getOperand(ii)->getDataType();
+ auto diffOperandType = differentiateType(builder, operandDataType);
+ SLANG_RELEASE_ASSERT(diffOperandType);
+ operandDataType = (IRType*)findOrTranscribePrimalInst(builder, operandDataType);
+ diffOperands.add(getDifferentialZeroOfType(builder, operandDataType));
+ }
+ ii++;
+ }
+
+ return InstPair(
+ primalMakeStruct,
+ builder->emitIntrinsicInst(
+ diffStructType,
+ kIROp_MakeStruct,
+ diffOperands.getCount(),
+ diffOperands.getBuffer()));
+ }
+ else
+ {
+ return InstPair(primalMakeStruct, nullptr);
+ }
+}
+
static bool _isDifferentiableFunc(IRInst* func)
{
func = getResolvedInstForDecorations(func);
@@ -1551,10 +1609,11 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_IntCast:
case kIROp_FloatCast:
case kIROp_MakeVectorFromScalar:
- case kIROp_MakeStruct:
case kIROp_MakeArray:
case kIROp_MakeArrayFromElement:
return transcribeConstruct(builder, origInst);
+ case kIROp_MakeStruct:
+ return transcribeMakeStruct(builder, origInst);
case kIROp_LookupWitness:
return transcribeLookupInterfaceMethod(builder, as<IRLookupWitnessMethod>(origInst));
diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h
index 6032c2319..e9774be49 100644
--- a/source/slang/slang-ir-autodiff-fwd.h
+++ b/source/slang/slang-ir-autodiff-fwd.h
@@ -46,6 +46,7 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase
// and then return nullptr. Literals do not need to be differentiated.
//
InstPair transcribeConstruct(IRBuilder* builder, IRInst* origConstruct);
+ InstPair transcribeMakeStruct(IRBuilder* builder, IRInst* origMakeStruct);
// Differentiating a call instruction here is primarily about generating
// an appropriate call list based on whichever parameters have differentials