summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2025-03-17 08:05:14 -0700
committerGitHub <noreply@github.com>2025-03-17 15:05:14 +0000
commit714ee76af46b96c32724f0d6edb159fddeffc6bf (patch)
tree3ac6fc10580acd4cf250f5439c8d88aa1457fb6e /source
parent98ff41989b04ce883e9dc9f4464c45290d30c560 (diff)
Fix crash when swizzling non-differentiable types (#6613)
* Fix crash when swizzling non-differentiable types * Update slang-ir-autodiff-fwd.cpp
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp30
1 files changed, 19 insertions, 11 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 0302d9ce7..e146ac3e0 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -955,20 +955,28 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
InstPair ForwardDiffTranscriber::transcribeSwizzle(IRBuilder* builder, IRSwizzle* origSwizzle)
{
IRInst* primalSwizzle = maybeCloneForPrimalInst(builder, origSwizzle);
-
if (auto diffBase = lookupDiffInst(origSwizzle->getBase(), nullptr))
{
- List<IRInst*> swizzleIndices;
- for (UIndex ii = 0; ii < origSwizzle->getElementCount(); ii++)
- swizzleIndices.add(origSwizzle->getElementIndex(ii));
+ // `diffBase` may exist even if the type is non-differentiable (e.g. IRCall inst that
+ // creates other differentiable outputs).
+ //
+ // We'll check to see if we can get a differential for the type in order to determine
+ // whether to generate a differential swizzle inst.
+ //
+ if (auto diffType = differentiateType(builder, primalSwizzle->getDataType()))
+ {
+ List<IRInst*> swizzleIndices;
+ for (UIndex ii = 0; ii < origSwizzle->getElementCount(); ii++)
+ swizzleIndices.add(origSwizzle->getElementIndex(ii));
- return InstPair(
- primalSwizzle,
- builder->emitSwizzle(
- differentiateType(builder, primalSwizzle->getDataType()),
- diffBase,
- origSwizzle->getElementCount(),
- swizzleIndices.getBuffer()));
+ return InstPair(
+ primalSwizzle,
+ builder->emitSwizzle(
+ diffType,
+ diffBase,
+ origSwizzle->getElementCount(),
+ swizzleIndices.getBuffer()));
+ }
}
return InstPair(primalSwizzle, nullptr);