diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 16 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.h | 2 |
2 files changed, 18 insertions, 0 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 6a87550d2..e0b916090 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1380,6 +1380,19 @@ InstPair ForwardDiffTranscriber::transcribeMakeExistential(IRBuilder* builder, I return InstPair(primalResult, diffResult); } +InstPair ForwardDiffTranscriber::transcribeDefaultConstruct(IRBuilder* builder, IRInst* origInst) +{ + IRInst* primalConstruct = maybeCloneForPrimalInst(builder, origInst); + + IRInst* diffConstruct = nullptr; + + if (auto diffType = differentiateType(builder, origInst->getDataType())) + { + diffConstruct = builder->emitDefaultConstructRaw(diffType); + } + return InstPair(primalConstruct, diffConstruct); +} + InstPair ForwardDiffTranscriber::transcribeWrapExistential(IRBuilder* builder, IRInst* origInst) { auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origInst->getDataType()); @@ -1813,6 +1826,9 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_WrapExistential: return transcribeWrapExistential(builder, origInst); + case kIROp_DefaultConstruct: + return transcribeDefaultConstruct(builder, origInst); + case kIROp_undefined: return transcribeUndefined(builder, origInst); diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index 91193edc1..4edb9301a 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -86,6 +86,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase InstPair transcribeWrapExistential(IRBuilder* builder, IRInst* origInst); + InstPair transcribeDefaultConstruct(IRBuilder* builder, IRInst* origInst); + virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) override; void generateTrivialFwdDiffFunc(IRFunc* primalFunc, IRFunc* diffFunc); |
