summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ir-autodiff-cfg-norm.cpp103
-rw-r--r--source/slang/slang-ir-autodiff-cfg-norm.h1
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp16
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp4
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp6
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h11
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp5
-rw-r--r--tests/autodiff/reverse-multi-return.slang50
-rw-r--r--tests/autodiff/reverse-multi-return.slang.expected.txt7
-rw-r--r--tests/autodiff/reverse-single-iter-loop.slang8
10 files changed, 139 insertions, 72 deletions
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp
index 4e0a413db..2b201466b 100644
--- a/source/slang/slang-ir-autodiff-cfg-norm.cpp
+++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp
@@ -97,14 +97,6 @@ struct CFGNormalizationPass
builder->setInsertInto(afterBlock);
unreachInst->removeAndDeallocate();
- /*
- HashSet<IRBlock*> predecessorSet;
- for (auto predecessor : parentAfterBlock->getPredecessors())
- predecessorSet.Add(predecessor);
-
- SLANG_ASSERT(predecessorSet.Count() <= 1);
- */
-
builder->emitBranch(parentAfterBlock);
}
}
@@ -169,6 +161,45 @@ struct CFGNormalizationPass
IRBlock* parentAfterBlock = afterBlocks[0];
+ auto addBreakBypassBranch = [&](IRBlock* block)
+ {
+ // We could arrive at the after-block before or
+ // after encountering a break statement.
+ // To handle this, we'll split the flow by checking the break flag
+ //
+ builder.setInsertAfter(block);
+
+ auto preAfterSplitBlock = builder.emitBlock();
+ preAfterSplitBlock->insertBefore(block);
+
+ auto afterSplitBlock = builder.emitBlock();
+ afterSplitBlock->insertBefore(block);
+
+ block->replaceUsesWith(preAfterSplitBlock);
+
+ builder.setInsertInto(preAfterSplitBlock);
+ builder.emitBranch(afterSplitBlock);
+
+ // Converging block for the split that we're making.
+ auto afterSplitAfterBlock = builder.emitBlock();
+
+ builder.setInsertInto(afterSplitBlock);
+ auto breakFlagValue = builder.emitLoad(parentRegion->breakVar);
+
+ builder.emitIfElse(
+ breakFlagValue,
+ block,
+ afterSplitAfterBlock,
+ afterSplitAfterBlock);
+
+ // At this point, we need to place afterSplitAfterBlock between
+ // at the _end_ of this region, but we aren't there yet (and
+ // don't know which block is the end of this region)
+ // Therefore, we'll defer this step and add it to a list for later.
+ //
+ pendingAfterBlocks.add(afterSplitAfterBlock);
+ };
+
// Follow this thread of execution till we hit an
// acceptable after block.
//
@@ -210,12 +241,15 @@ struct CFGNormalizationPass
auto afterBlock = ifElse->getAfterBlock();
// Trivial case, both end-points branch into the after block
- if (trueTargetBlock == afterBlock &&
+ /*if (trueTargetBlock == afterBlock &&
falseTargetBlock == afterBlock)
{
+ if ()
+ addBreakBypassBranch(afterBlock);
currentBlock = afterBlock;
+ // TODO: Need to split block.
break;
- }
+ }*/
auto afterBreakRegion = false;
auto afterBaseRegion = false;
@@ -281,41 +315,7 @@ struct CFGNormalizationPass
// Do we need to split the after region?
if (afterBaseRegion && afterBreakRegion)
{
- // We could arrive at the after-block before or
- // after encountering a break statement.
- // To handle this, we'll split the flow by checking the break flag
- //
- builder.setInsertAfter(afterBlock);
-
- auto preAfterSplitBlock = builder.emitBlock();
- preAfterSplitBlock->insertBefore(afterBlock);
-
- auto afterSplitBlock = builder.emitBlock();
- afterSplitBlock->insertBefore(afterBlock);
-
- afterBlock->replaceUsesWith(preAfterSplitBlock);
-
- builder.setInsertInto(preAfterSplitBlock);
- builder.emitBranch(afterSplitBlock);
-
- // Converging block for the split that we're making.
- auto afterSplitAfterBlock = builder.emitBlock();
-
- builder.setInsertInto(afterSplitBlock);
- auto breakFlagValue = builder.emitLoad(parentRegion->breakVar);
-
- builder.emitIfElse(
- breakFlagValue,
- afterBlock,
- afterSplitAfterBlock,
- afterSplitAfterBlock);
-
- // At this point, we need to place afterSplitAfterBlock between
- // at the _end_ of this region, but we aren't there yet (and
- // don't know which block is the end of this region)
- // Therefore, we'll defer this step and add it to a list for later.
- //
- pendingAfterBlocks.add(afterSplitAfterBlock);
+ addBreakBypassBranch(afterBlock);
// Update current block.
currentBlock = afterBlock;
@@ -419,12 +419,6 @@ struct CFGNormalizationPass
if (isLoopTrivial(as<IRLoop>(branchInst)))
{
auto firstLoopBlock = as<IRLoop>(branchInst)->getTargetBlock();
- auto terminator = firstLoopBlock->getTerminator();
-
- // We really shouldn't see a conditional branch on a trivial loop
- // but if we hit this assert, handle this case.
- //
- SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>(terminator));
// Normalize the region from the first loop block till break.
auto preBreakEndPoint = getNormalizedRegionEndpoint(
@@ -583,6 +577,7 @@ struct CFGNormalizationPass
};
void normalizeCFG(
+ SharedIRBuilder* sharedBuilder,
IRGlobalValueWithCode* func,
IRCFGNormalizationPass const& options)
{
@@ -591,9 +586,7 @@ void normalizeCFG(
//
eliminatePhisInFunc(LivenessMode::Disabled, func->getModule(), func);
- SharedIRBuilder sharedBuilder(func->getModule());
- sharedBuilder.deduplicateAndRebuildGlobalNumberingMap();
- CFGNormalizationContext context = {&sharedBuilder, options.sink};
+ CFGNormalizationContext context = {sharedBuilder, options.sink};
CFGNormalizationPass cfgPass(context);
List<IRBlock*> workList;
@@ -622,7 +615,7 @@ void normalizeCFG(
}
disableIRValidationAtInsert();
- constructSSA(&sharedBuilder, func);
+ constructSSA(sharedBuilder, func);
enableIRValidationAtInsert();
}
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.h b/source/slang/slang-ir-autodiff-cfg-norm.h
index 2a39f7695..f256d8ce8 100644
--- a/source/slang/slang-ir-autodiff-cfg-norm.h
+++ b/source/slang/slang-ir-autodiff-cfg-norm.h
@@ -19,6 +19,7 @@ namespace Slang
/// "after" block.
///
void normalizeCFG(
+ SharedIRBuilder* sharedBuilder,
IRGlobalValueWithCode* func,
IRCFGNormalizationPass const& options = IRCFGNormalizationPass());
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index abe3f718c..f60412efb 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -527,6 +527,9 @@ InstPair ForwardDiffTranscriber::transcribeControlFlow(IRBuilder* builder, IRIns
auto diffArg = lookupDiffInst(origArg, nullptr);
if (diffArg)
newArgs.add(diffArg);
+ else
+ newArgs.add(
+ getDifferentialZeroOfType(builder, origArg->getDataType()));
}
}
@@ -576,16 +579,15 @@ InstPair ForwardDiffTranscriber::transcribeControlFlow(IRBuilder* builder, IRIns
return InstPair(nullptr, nullptr);
}
-InstPair ForwardDiffTranscriber::transcribeConst(IRBuilder* builder, IRInst* origInst)
+InstPair ForwardDiffTranscriber::transcribeConst(IRBuilder*, IRInst* origInst)
{
switch(origInst->getOp())
{
case kIROp_FloatLit:
- return InstPair(origInst, builder->getFloatValue(origInst->getDataType(), 0.0f));
+ case kIROp_IntLit:
+ return InstPair(origInst, nullptr);
case kIROp_VoidLit:
return InstPair(origInst, origInst);
- case kIROp_IntLit:
- return InstPair(origInst, builder->getIntValue(origInst->getDataType(), 0));
}
getSink()->diagnose(
@@ -943,9 +945,15 @@ InstPair ForwardDiffTranscriber::transcribeMakeDifferentialPair(IRBuilder* build
SLANG_ASSERT(primalVal);
auto diffPrimalVal = findOrTranscribePrimalInst(builder, origInst->getDifferentialValue());
SLANG_ASSERT(diffPrimalVal);
+
auto primalDiffVal = findOrTranscribeDiffInst(builder, origInst->getPrimalValue());
+ if (!primalDiffVal)
+ primalDiffVal = getDifferentialZeroOfType(builder, origInst->getPrimalValue()->getDataType());
SLANG_ASSERT(primalDiffVal);
+
auto diffDiffVal = findOrTranscribeDiffInst(builder, origInst->getDifferentialValue());
+ if (!diffDiffVal)
+ diffDiffVal = getDifferentialZeroOfType(builder, origInst->getDifferentialValue()->getDataType());
SLANG_ASSERT(diffDiffVal);
auto primalPairType = findOrTranscribePrimalInst(builder, origInst->getFullType());
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 8f218293d..0f2ceceb4 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -510,7 +510,7 @@ namespace Slang
eliminateMultiLevelBreakForFunc(func->getModule(), func);
IRCFGNormalizationPass cfgPass = {this->getSink()};
- normalizeCFG(func);
+ normalizeCFG(autoDiffSharedContext->sharedBuilder, func);
AutoDiffAddressConversionPolicy cvtPolicty;
cvtPolicty.diffTypeContext = &diffTypeContext;
@@ -544,6 +544,8 @@ namespace Slang
// reversible.
if (SLANG_FAILED(prepareFuncForBackwardDiff(primalFunc)))
return diffPropagateFunc;
+
+ autoDiffSharedContext->sharedBuilder->deduplicateAndRebuildGlobalNumberingMap();
// Forward transcribe the clone of the original func.
ForwardDiffTranscriber& fwdTranscriber = *static_cast<ForwardDiffTranscriber*>(
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 91374e006..520c6d276 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -911,12 +911,14 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst
// Tag the differential inst using a decoration (if it doesn't have one)
if (!pair.differential->findDecoration<IRDifferentialInstDecoration>() &&
- !pair.differential->findDecoration<IRMixedDifferentialInstDecoration>())
+ !pair.differential->findDecoration<IRMixedDifferentialInstDecoration>() &&
+ !as<IRConstant>(pair.differential))
{
// TODO: If the type is a 'relevant' pair type, need to mark it as mixed differential
// instead.
//
- builder->markInstAsDifferential(pair.differential, as<IRType>(pair.primal->getDataType()));
+ auto primalType = as<IRType>(pair.primal->getDataType());
+ builder->markInstAsDifferential(pair.differential, primalType);
}
break;
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index f87aa7751..5aad6e3a3 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -1694,9 +1694,6 @@ struct DiffTransposePass
IRInst* newInst = builder->emitMakeVector(targetType, operands.getCount(), operands.getBuffer());
- if (isDifferentialInst(inst))
- builder->markInstAsDifferential(newInst);
-
return newInst;
}
@@ -1725,6 +1722,11 @@ struct DiffTransposePass
builder->setInsertAfter(operand);
IRInst* newOperand = promoteToType(builder, targetType, operand);
+
+ if (isDifferentialInst(operand))
+ builder->markInstAsDifferential(
+ newOperand, tryGetPrimalTypeFromDiffInst(fwdInst));
+
newOperands.add(newOperand);
needNewInst = true;
@@ -1747,7 +1749,8 @@ struct DiffTransposePass
builder->setInsertLoc(oldLoc);
if (isDifferentialInst(fwdInst))
- builder->markInstAsDifferential(newInst);
+ builder->markInstAsDifferential(
+ newInst, tryGetPrimalTypeFromDiffInst(fwdInst));
return newInst;
}
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 44cb2aa09..daf6e44d4 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -322,7 +322,10 @@ struct ExtractPrimalFuncContext
{
if (shouldStoreInst(inst))
{
- builder.setInsertAfter(inst);
+ if (as<IRParam>(inst))
+ builder.setInsertBefore(block->getFirstOrdinaryInst());
+ else
+ builder.setInsertAfter(inst);
storeInst(builder, inst, outIntermediary);
}
else if (inst->getOp() == kIROp_Var)
diff --git a/tests/autodiff/reverse-multi-return.slang b/tests/autodiff/reverse-multi-return.slang
new file mode 100644
index 000000000..ee8bb9a4c
--- /dev/null
+++ b/tests/autodiff/reverse-multi-return.slang
@@ -0,0 +1,50 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef DifferentialPair<float> dpfloat;
+typedef float.Differential dfloat;
+
+[BackwardDifferentiable]
+float test_multi_return(float y)
+{
+ if (y > 0.6)
+ {
+ if (y > 0.8)
+ {
+ return y * 10.0f;
+ }
+ else
+ {
+ return y * 4.0f;
+ }
+ }
+ return y * 6.0f;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ {
+ dpfloat dpa = dpfloat(1.0, 0.0);
+
+ __bwd_diff(test_multi_return)(dpa, 1.0f);
+ outputBuffer[0] = dpa.d; // Expect: 10.0
+ }
+
+ {
+ dpfloat dpa = dpfloat(0.4, 0.0);
+
+ __bwd_diff(test_multi_return)(dpa, 1.0f);
+ outputBuffer[1] = dpa.d; // Expect: 6.0
+ }
+
+ {
+ dpfloat dpa = dpfloat(0.7, 0.0);
+
+ __bwd_diff(test_multi_return)(dpa, 1.0f);
+ outputBuffer[2] = dpa.d; // Expect: 4.0
+ }
+}
diff --git a/tests/autodiff/reverse-multi-return.slang.expected.txt b/tests/autodiff/reverse-multi-return.slang.expected.txt
new file mode 100644
index 000000000..115191b13
--- /dev/null
+++ b/tests/autodiff/reverse-multi-return.slang.expected.txt
@@ -0,0 +1,7 @@
+type: float
+10.000000
+6.000000
+4.000000
+0.000000
+0.000000
+0.000000
diff --git a/tests/autodiff/reverse-single-iter-loop.slang b/tests/autodiff/reverse-single-iter-loop.slang
index 47232147a..20c26e000 100644
--- a/tests/autodiff/reverse-single-iter-loop.slang
+++ b/tests/autodiff/reverse-single-iter-loop.slang
@@ -49,11 +49,9 @@ float test_nested_if_else_single_iter_loop(float y)
break;
}
}
- else
- {
- x = y * 6.0f;
- break;
- }
+
+ x = y * 6.0f;
+ break;
}
return x;