summaryrefslogtreecommitdiff
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h19
1 files changed, 19 insertions, 0 deletions
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index c7ac8c357..55f042352 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -1744,6 +1744,9 @@ struct DiffTransposePass
case kIROp_UpdateElement:
return transposeUpdateElement(builder, fwdInst, revValue);
+ case kIROp_FloatCast:
+ return transposeFloatCast(builder, fwdInst, revValue);
+
case kIROp_LoadReverseGradient:
case kIROp_ReverseGradientDiffPairRef:
case kIROp_DefaultConstruct:
@@ -2228,6 +2231,22 @@ struct DiffTransposePass
return TranspositionResult(gradients);
}
+ TranspositionResult transposeFloatCast(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue)
+ {
+ // (A = cast<T, U>(B)) -> (dB += cast<U, T>(dA))
+ return TranspositionResult(
+ List<RevGradient>(
+ RevGradient(
+ RevGradient::Flavor::Simple,
+ fwdInst->getOperand(0),
+ builder->emitIntrinsicInst(
+ fwdInst->getOperand(0)->getDataType(),
+ kIROp_FloatCast,
+ 1,
+ &revValue),
+ fwdInst)));
+ }
+
// Gather all reverse-mode gradients for a Load inst, aggregate them and store them in the ptr.
//
void accumulateGradientsForLoad(IRBuilder* builder, IRLoad* revLoad)