summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-11-02 14:54:22 -0700
committerGitHub <noreply@github.com>2023-11-02 14:54:22 -0700
commit911a4401b08f6199e18b32349c236c186a2dd128 (patch)
tree75cd31ceb7a1c134f41cc8c44a08cd9123c27613
parent72e95f2c62b39ef1ddb6c169a9452a3b4fcb22a5 (diff)
Fix crash when writing to `no_diff` out parameter. (#3308)
* Fix crash when writing to `no_diff` out parameter. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp76
-rw-r--r--source/slang/slang-ir-autodiff-fwd.h7
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h13
-rw-r--r--source/slang/slang-ir-util.cpp13
-rw-r--r--source/slang/slang-ir-util.h2
-rw-r--r--tests/autodiff/no-diff-out.slang30
6 files changed, 52 insertions, 89 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 8cf7d3e78..213b53df0 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -841,16 +841,18 @@ InstPair ForwardDiffTranscriber::transcribeControlFlow(IRBuilder* builder, IRIns
case kIROp_unconditionalBranch:
case kIROp_loop:
auto origBranch = as<IRUnconditionalBranch>(origInst);
+ auto targetBlock = origBranch->getTargetBlock();
// Grab the differentials for any phi nodes.
List<IRInst*> newArgs;
for (UIndex ii = 0; ii < origBranch->getArgCount(); ii++)
{
+ auto origParam = getParamAt(targetBlock, ii);
auto origArg = origBranch->getArg(ii);
auto primalArg = lookupPrimalInst(builder, origArg);
newArgs.add(primalArg);
- if (differentiateType(builder, origArg->getDataType()))
+ if (differentiateType(builder, origParam->getDataType()))
{
auto diffArg = lookupDiffInst(origArg, nullptr);
if (diffArg)
@@ -869,6 +871,7 @@ InstPair ForwardDiffTranscriber::transcribeControlFlow(IRBuilder* builder, IRIns
auto breakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock());
auto continueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock());
List<IRInst*> operands;
+ operands.add(diffBlock);
operands.add(breakBlock);
operands.add(continueBlock);
operands.addRange(newArgs);
@@ -877,6 +880,8 @@ InstPair ForwardDiffTranscriber::transcribeControlFlow(IRBuilder* builder, IRIns
kIROp_loop,
operands.getCount(),
operands.getBuffer());
+ if (auto maxItersDecoration = origLoop->findDecoration<IRLoopMaxItersDecoration>())
+ builder->addLoopMaxItersDecoration(diffBranch, maxItersDecoration->getMaxIters());
}
else
{
@@ -1158,71 +1163,6 @@ InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRI
return InstPair(primalUpdateField, diffUpdateElement);
}
-List<IRInst*> ForwardDiffTranscriber::transcribePhiArgs(IRBuilder* builder, List<IRInst*> origPhiArgs)
-{
- // Grab the differentials for any phi nodes.
- List<IRInst*> newArgs;
- for (auto origArg : origPhiArgs)
- {
- auto primalArg = lookupPrimalInst(builder, origArg);
- newArgs.add(primalArg);
-
- if (differentiateType(builder, origArg->getDataType()))
- {
- auto diffArg = lookupDiffInst(origArg, nullptr);
- if (diffArg)
- newArgs.add(diffArg);
- else
- newArgs.add(
- getDifferentialZeroOfType(builder, origArg->getDataType()));
- }
- }
-
- return newArgs;
-}
-
-InstPair ForwardDiffTranscriber::transcribeLoop(IRBuilder* builder, IRLoop* origLoop)
-{
- // The loop comes with three blocks.. we just need to transcribe each one
- // and assemble the new loop instruction.
-
- // Transcribe the target block (this is the 'condition' part of the loop, which
- // will branch into the loop body)
- auto diffTargetBlock = findOrTranscribeDiffInst(builder, origLoop->getTargetBlock());
-
- // Transcribe the continue block (this is the 'update' part of the loop, which will
- // branch into the condition block)
- auto diffContinueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock());
-
- // Transcribe the break block (this is the block after the exiting the loop)
- auto diffBreakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock());
-
- List<IRInst*> diffLoopOperands;
- diffLoopOperands.add(diffTargetBlock);
- diffLoopOperands.add(diffBreakBlock);
- diffLoopOperands.add(diffContinueBlock);
-
- List<IRInst*> phiArgs;
- for (UIndex ii = diffLoopOperands.getCount(); ii < origLoop->getOperandCount(); ii++)
- phiArgs.add(origLoop->getOperand(ii));
-
- auto newPhiArgs = transcribePhiArgs(builder, phiArgs);
- for (auto newArg : newPhiArgs)
- diffLoopOperands.add(newArg);
-
- IRInst* diffLoop = builder->emitIntrinsicInst(
- nullptr,
- kIROp_loop,
- diffLoopOperands.getCount(),
- diffLoopOperands.getBuffer());
- builder->markInstAsMixedDifferential(diffLoop);
-
- if (auto maxItersDecoration = origLoop->findDecoration<IRLoopMaxItersDecoration>())
- builder->addLoopMaxItersDecoration(diffLoop, maxItersDecoration->getMaxIters());
-
- return InstPair(diffLoop, diffLoop);
-}
-
InstPair ForwardDiffTranscriber::transcribeSwitch(IRBuilder* builder, IRSwitch* origSwitch)
{
// Transcribe condition (primal only, conditions do not produce differentials)
@@ -1858,6 +1798,7 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
return transcribeUpdateElement(builder, origInst);
case kIROp_unconditionalBranch:
+ case kIROp_loop:
return transcribeControlFlow(builder, origInst);
case kIROp_FloatLit:
@@ -1875,9 +1816,6 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_GetElement:
case kIROp_GetElementPtr:
return transcribeGetElement(builder, origInst);
-
- case kIROp_loop:
- return transcribeLoop(builder, as<IRLoop>(origInst));
case kIROp_ifElse:
return transcribeIfElse(builder, as<IRIfElse>(origInst));
diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h
index 8d8d65c10..f88235558 100644
--- a/source/slang/slang-ir-autodiff-fwd.h
+++ b/source/slang/slang-ir-autodiff-fwd.h
@@ -70,8 +70,6 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase
InstPair transcribeUpdateElement(IRBuilder* builder, IRInst* originalInst);
- InstPair transcribeLoop(IRBuilder* builder, IRLoop* origLoop);
-
InstPair transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse);
InstPair transcribeSwitch(IRBuilder* builder, IRSwitch* origSwitch);
@@ -97,14 +95,9 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase
// Transcribe a function definition.
InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc);
- // Transcribe a generic definition
- InstPair transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric);
-
// Transcribe a function without marking the result as a decoration.
IRFunc* transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc);
- List<IRInst*> transcribePhiArgs(IRBuilder* builder, List<IRInst*> origPhiArgs);
-
void checkAutodiffInstDecorations(IRFunc* fwdFunc);
SlangResult prepareFuncForForwardDiff(IRFunc* func);
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index a3a4eb2b3..2283ebf5c 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -739,19 +739,6 @@ struct DiffTransposePass
return false;
}
-
- IRParam* getParamAt(IRBlock* block, UIndex ii)
- {
- UIndex index = 0;
- for (auto param : block->getParams())
- {
- if (ii == index)
- return param;
-
- index ++;
- }
- SLANG_UNEXPECTED("ii >= paramCount");
- }
void transposeBlock(IRBlock* fwdBlock, IRBlock* revBlock)
{
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 073b8bf96..6afb6b719 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -1183,6 +1183,19 @@ void hoistInstOutOfASMBlocks(IRBlock* block)
}
}
+IRParam* getParamAt(IRBlock* block, UIndex ii)
+{
+ UIndex index = 0;
+ for (auto param : block->getParams())
+ {
+ if (ii == index)
+ return param;
+
+ index++;
+ }
+ SLANG_UNEXPECTED("ii >= paramCount");
+}
+
UnownedStringSlice getBasicTypeNameHint(IRType* basicType)
{
switch (basicType->getOp())
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 0b377a3d1..096b008d9 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -300,6 +300,8 @@ inline bool isCompositeType(IRType* type)
}
}
+IRParam* getParamAt(IRBlock* block, UIndex ii);
+
}
#endif
diff --git a/tests/autodiff/no-diff-out.slang b/tests/autodiff/no-diff-out.slang
new file mode 100644
index 000000000..c8085a05f
--- /dev/null
+++ b/tests/autodiff/no-diff-out.slang
@@ -0,0 +1,30 @@
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef DifferentialPair<float> dpfloat;
+
+[Differentiable]
+float test(float x, no_diff out float y)
+{
+ if (x == 1.0)
+ y = 0.0;
+ return x * x;
+}
+
+[Differentiable]
+float caller(float x, no_diff out float y)
+{
+ return test(x, y);
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ var p = diffPair(3.0, 0.0);
+ bwd_diff(caller)(p, 1.0);
+ outputBuffer[dispatchThreadID.x] = p.d;
+ // CHECK: 6.0
+}