diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.h | 5 |
3 files changed, 6 insertions, 10 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index fa1ca8519..ef6178976 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -295,9 +295,6 @@ namespace Slang // Create an empty func to represent the transcribed func of `origFunc`. InstPair BackwardDiffTranscriberBase::transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc) { - if (auto bwdDiffFunc = findExistingDiffFunc(origFunc)) - return InstPair(origFunc, bwdDiffFunc); - if (!isBackwardDifferentiableFunc(origFunc)) return InstPair(nullptr, nullptr); @@ -379,11 +376,15 @@ namespace Slang InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) { + if (auto bwdDiffFunc = findExistingDiffFunc(origFunc)) + return InstPair(origFunc, bwdDiffFunc); + auto header = transcribeFuncHeaderImpl(inBuilder, origFunc); if (!header.differential) return header; - + IRBuilder builder = *inBuilder; + builder.setInsertInto(header.differential); builder.emitBlock(); auto origFuncType = as<IRFuncType>(origFunc->getFullType()); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 7ab0ee692..a8e06bf91 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -879,7 +879,7 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst) { - // If a differential intstruction is already mapped for + // If a differential instruction is already mapped for // this original inst, return that. // if (auto diffInst = lookupDiffInst(origInst, nullptr)) diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index 7bcd4c90b..f57fb2974 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -170,11 +170,6 @@ struct DifferentiableTypeConformanceContext { switch (origType->getOp()) { - case kIROp_FloatType: - case kIROp_HalfType: - case kIROp_DoubleType: - return origType; - case kIROp_ArrayType: { auto diffElementType = (IRType*)getDifferentialForType( |
