summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-04-21 18:57:39 -0400
committerGitHub <noreply@github.com>2023-04-21 15:57:39 -0700
commit385d3f4d29902242d7a803fb7b3d2a7513e4b5b1 (patch)
tree690728e166de4884251241a48cec0f5a404307b7 /source
parent957a4d3eb0a14a9d57bbb325ef0e1d458df2d2b9 (diff)
Add support for `kIROp_FloatCast` (#2824)
* Add support for `kIROp_FloatCast` * Update float-cast.slang * Added flag to d3d options
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h19
1 files changed, 19 insertions, 0 deletions
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index c7ac8c357..55f042352 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -1744,6 +1744,9 @@ struct DiffTransposePass
case kIROp_UpdateElement:
return transposeUpdateElement(builder, fwdInst, revValue);
+ case kIROp_FloatCast:
+ return transposeFloatCast(builder, fwdInst, revValue);
+
case kIROp_LoadReverseGradient:
case kIROp_ReverseGradientDiffPairRef:
case kIROp_DefaultConstruct:
@@ -2228,6 +2231,22 @@ struct DiffTransposePass
return TranspositionResult(gradients);
}
+ TranspositionResult transposeFloatCast(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue)
+ {
+ // (A = cast<T, U>(B)) -> (dB += cast<U, T>(dA))
+ return TranspositionResult(
+ List<RevGradient>(
+ RevGradient(
+ RevGradient::Flavor::Simple,
+ fwdInst->getOperand(0),
+ builder->emitIntrinsicInst(
+ fwdInst->getOperand(0)->getDataType(),
+ kIROp_FloatCast,
+ 1,
+ &revValue),
+ fwdInst)));
+ }
+
// Gather all reverse-mode gradients for a Load inst, aggregate them and store them in the ptr.
//
void accumulateGradientsForLoad(IRBuilder* builder, IRLoad* revLoad)