diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 30 |
1 files changed, 19 insertions, 11 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 0302d9ce7..e146ac3e0 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -955,20 +955,28 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig InstPair ForwardDiffTranscriber::transcribeSwizzle(IRBuilder* builder, IRSwizzle* origSwizzle) { IRInst* primalSwizzle = maybeCloneForPrimalInst(builder, origSwizzle); - if (auto diffBase = lookupDiffInst(origSwizzle->getBase(), nullptr)) { - List<IRInst*> swizzleIndices; - for (UIndex ii = 0; ii < origSwizzle->getElementCount(); ii++) - swizzleIndices.add(origSwizzle->getElementIndex(ii)); + // `diffBase` may exist even if the type is non-differentiable (e.g. IRCall inst that + // creates other differentiable outputs). + // + // We'll check to see if we can get a differential for the type in order to determine + // whether to generate a differential swizzle inst. + // + if (auto diffType = differentiateType(builder, primalSwizzle->getDataType())) + { + List<IRInst*> swizzleIndices; + for (UIndex ii = 0; ii < origSwizzle->getElementCount(); ii++) + swizzleIndices.add(origSwizzle->getElementIndex(ii)); - return InstPair( - primalSwizzle, - builder->emitSwizzle( - differentiateType(builder, primalSwizzle->getDataType()), - diffBase, - origSwizzle->getElementCount(), - swizzleIndices.getBuffer())); + return InstPair( + primalSwizzle, + builder->emitSwizzle( + diffType, + diffBase, + origSwizzle->getElementCount(), + swizzleIndices.getBuffer())); + } } return InstPair(primalSwizzle, nullptr); |
