diff options
| author | Yong He <yonghe@outlook.com> | 2023-11-02 14:54:22 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-11-02 14:54:22 -0700 |
| commit | 911a4401b08f6199e18b32349c236c186a2dd128 (patch) | |
| tree | 75cd31ceb7a1c134f41cc8c44a08cd9123c27613 | |
| parent | 72e95f2c62b39ef1ddb6c169a9452a3b4fcb22a5 (diff) | |
Fix crash when writing to `no_diff` out parameter. (#3308)
* Fix crash when writing to `no_diff` out parameter.
* Fix.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 76 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 13 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 13 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.h | 2 | ||||
| -rw-r--r-- | tests/autodiff/no-diff-out.slang | 30 |
6 files changed, 52 insertions, 89 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 8cf7d3e78..213b53df0 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -841,16 +841,18 @@ InstPair ForwardDiffTranscriber::transcribeControlFlow(IRBuilder* builder, IRIns case kIROp_unconditionalBranch: case kIROp_loop: auto origBranch = as<IRUnconditionalBranch>(origInst); + auto targetBlock = origBranch->getTargetBlock(); // Grab the differentials for any phi nodes. List<IRInst*> newArgs; for (UIndex ii = 0; ii < origBranch->getArgCount(); ii++) { + auto origParam = getParamAt(targetBlock, ii); auto origArg = origBranch->getArg(ii); auto primalArg = lookupPrimalInst(builder, origArg); newArgs.add(primalArg); - if (differentiateType(builder, origArg->getDataType())) + if (differentiateType(builder, origParam->getDataType())) { auto diffArg = lookupDiffInst(origArg, nullptr); if (diffArg) @@ -869,6 +871,7 @@ InstPair ForwardDiffTranscriber::transcribeControlFlow(IRBuilder* builder, IRIns auto breakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock()); auto continueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock()); List<IRInst*> operands; + operands.add(diffBlock); operands.add(breakBlock); operands.add(continueBlock); operands.addRange(newArgs); @@ -877,6 +880,8 @@ InstPair ForwardDiffTranscriber::transcribeControlFlow(IRBuilder* builder, IRIns kIROp_loop, operands.getCount(), operands.getBuffer()); + if (auto maxItersDecoration = origLoop->findDecoration<IRLoopMaxItersDecoration>()) + builder->addLoopMaxItersDecoration(diffBranch, maxItersDecoration->getMaxIters()); } else { @@ -1158,71 +1163,6 @@ InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRI return InstPair(primalUpdateField, diffUpdateElement); } -List<IRInst*> ForwardDiffTranscriber::transcribePhiArgs(IRBuilder* builder, List<IRInst*> origPhiArgs) -{ - // Grab the differentials for any phi nodes. - List<IRInst*> newArgs; - for (auto origArg : origPhiArgs) - { - auto primalArg = lookupPrimalInst(builder, origArg); - newArgs.add(primalArg); - - if (differentiateType(builder, origArg->getDataType())) - { - auto diffArg = lookupDiffInst(origArg, nullptr); - if (diffArg) - newArgs.add(diffArg); - else - newArgs.add( - getDifferentialZeroOfType(builder, origArg->getDataType())); - } - } - - return newArgs; -} - -InstPair ForwardDiffTranscriber::transcribeLoop(IRBuilder* builder, IRLoop* origLoop) -{ - // The loop comes with three blocks.. we just need to transcribe each one - // and assemble the new loop instruction. - - // Transcribe the target block (this is the 'condition' part of the loop, which - // will branch into the loop body) - auto diffTargetBlock = findOrTranscribeDiffInst(builder, origLoop->getTargetBlock()); - - // Transcribe the continue block (this is the 'update' part of the loop, which will - // branch into the condition block) - auto diffContinueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock()); - - // Transcribe the break block (this is the block after the exiting the loop) - auto diffBreakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock()); - - List<IRInst*> diffLoopOperands; - diffLoopOperands.add(diffTargetBlock); - diffLoopOperands.add(diffBreakBlock); - diffLoopOperands.add(diffContinueBlock); - - List<IRInst*> phiArgs; - for (UIndex ii = diffLoopOperands.getCount(); ii < origLoop->getOperandCount(); ii++) - phiArgs.add(origLoop->getOperand(ii)); - - auto newPhiArgs = transcribePhiArgs(builder, phiArgs); - for (auto newArg : newPhiArgs) - diffLoopOperands.add(newArg); - - IRInst* diffLoop = builder->emitIntrinsicInst( - nullptr, - kIROp_loop, - diffLoopOperands.getCount(), - diffLoopOperands.getBuffer()); - builder->markInstAsMixedDifferential(diffLoop); - - if (auto maxItersDecoration = origLoop->findDecoration<IRLoopMaxItersDecoration>()) - builder->addLoopMaxItersDecoration(diffLoop, maxItersDecoration->getMaxIters()); - - return InstPair(diffLoop, diffLoop); -} - InstPair ForwardDiffTranscriber::transcribeSwitch(IRBuilder* builder, IRSwitch* origSwitch) { // Transcribe condition (primal only, conditions do not produce differentials) @@ -1858,6 +1798,7 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* return transcribeUpdateElement(builder, origInst); case kIROp_unconditionalBranch: + case kIROp_loop: return transcribeControlFlow(builder, origInst); case kIROp_FloatLit: @@ -1875,9 +1816,6 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_GetElement: case kIROp_GetElementPtr: return transcribeGetElement(builder, origInst); - - case kIROp_loop: - return transcribeLoop(builder, as<IRLoop>(origInst)); case kIROp_ifElse: return transcribeIfElse(builder, as<IRIfElse>(origInst)); diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index 8d8d65c10..f88235558 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -70,8 +70,6 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase InstPair transcribeUpdateElement(IRBuilder* builder, IRInst* originalInst); - InstPair transcribeLoop(IRBuilder* builder, IRLoop* origLoop); - InstPair transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse); InstPair transcribeSwitch(IRBuilder* builder, IRSwitch* origSwitch); @@ -97,14 +95,9 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase // Transcribe a function definition. InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc); - // Transcribe a generic definition - InstPair transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric); - // Transcribe a function without marking the result as a decoration. IRFunc* transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc); - List<IRInst*> transcribePhiArgs(IRBuilder* builder, List<IRInst*> origPhiArgs); - void checkAutodiffInstDecorations(IRFunc* fwdFunc); SlangResult prepareFuncForForwardDiff(IRFunc* func); diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index a3a4eb2b3..2283ebf5c 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -739,19 +739,6 @@ struct DiffTransposePass return false; } - - IRParam* getParamAt(IRBlock* block, UIndex ii) - { - UIndex index = 0; - for (auto param : block->getParams()) - { - if (ii == index) - return param; - - index ++; - } - SLANG_UNEXPECTED("ii >= paramCount"); - } void transposeBlock(IRBlock* fwdBlock, IRBlock* revBlock) { diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 073b8bf96..6afb6b719 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -1183,6 +1183,19 @@ void hoistInstOutOfASMBlocks(IRBlock* block) } } +IRParam* getParamAt(IRBlock* block, UIndex ii) +{ + UIndex index = 0; + for (auto param : block->getParams()) + { + if (ii == index) + return param; + + index++; + } + SLANG_UNEXPECTED("ii >= paramCount"); +} + UnownedStringSlice getBasicTypeNameHint(IRType* basicType) { switch (basicType->getOp()) diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 0b377a3d1..096b008d9 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -300,6 +300,8 @@ inline bool isCompositeType(IRType* type) } } +IRParam* getParamAt(IRBlock* block, UIndex ii); + } #endif diff --git a/tests/autodiff/no-diff-out.slang b/tests/autodiff/no-diff-out.slang new file mode 100644 index 000000000..c8085a05f --- /dev/null +++ b/tests/autodiff/no-diff-out.slang @@ -0,0 +1,30 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float> dpfloat; + +[Differentiable] +float test(float x, no_diff out float y) +{ + if (x == 1.0) + y = 0.0; + return x * x; +} + +[Differentiable] +float caller(float x, no_diff out float y) +{ + return test(x, y); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + var p = diffPair(3.0, 0.0); + bwd_diff(caller)(p, 1.0); + outputBuffer[dispatchThreadID.x] = p.d; + // CHECK: 6.0 +} |
