From 61eb17b0b556ccc06f65f921bb0a4ea2784c4e20 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 26 Apr 2023 20:04:36 -0400 Subject: Add support for `kIROp_DefaultConstruct` (#2845) --- source/slang/slang-ir-autodiff-fwd.cpp | 16 ++++++++++++++++ source/slang/slang-ir-autodiff-fwd.h | 2 ++ 2 files changed, 18 insertions(+) (limited to 'source') diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 6a87550d2..e0b916090 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1380,6 +1380,19 @@ InstPair ForwardDiffTranscriber::transcribeMakeExistential(IRBuilder* builder, I return InstPair(primalResult, diffResult); } +InstPair ForwardDiffTranscriber::transcribeDefaultConstruct(IRBuilder* builder, IRInst* origInst) +{ + IRInst* primalConstruct = maybeCloneForPrimalInst(builder, origInst); + + IRInst* diffConstruct = nullptr; + + if (auto diffType = differentiateType(builder, origInst->getDataType())) + { + diffConstruct = builder->emitDefaultConstructRaw(diffType); + } + return InstPair(primalConstruct, diffConstruct); +} + InstPair ForwardDiffTranscriber::transcribeWrapExistential(IRBuilder* builder, IRInst* origInst) { auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origInst->getDataType()); @@ -1813,6 +1826,9 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_WrapExistential: return transcribeWrapExistential(builder, origInst); + case kIROp_DefaultConstruct: + return transcribeDefaultConstruct(builder, origInst); + case kIROp_undefined: return transcribeUndefined(builder, origInst); diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index 91193edc1..4edb9301a 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -86,6 +86,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase InstPair transcribeWrapExistential(IRBuilder* builder, IRInst* origInst); + InstPair transcribeDefaultConstruct(IRBuilder* builder, IRInst* origInst); + virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) override; void generateTrivialFwdDiffFunc(IRFunc* primalFunc, IRFunc* diffFunc); -- cgit v1.2.3