summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-01-31 03:26:59 -0500
committerGitHub <noreply@github.com>2023-01-31 00:26:59 -0800
commite312d5c7dfde80941d96e522079a5d70f7d00649 (patch)
treecf600a7f49117a77336ad55e59816f5c323cd705 /source
parent77cdbb2101f4e27bf1800d4bc1077c0510668c25 (diff)
Patched support for multi-return and fallthrough if-else with break stmts (#2617)
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-cfg-norm.cpp103
-rw-r--r--source/slang/slang-ir-autodiff-cfg-norm.h1
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp16
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp4
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp6
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h11
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp5
7 files changed, 79 insertions, 67 deletions
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp
index 4e0a413db..2b201466b 100644
--- a/source/slang/slang-ir-autodiff-cfg-norm.cpp
+++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp
@@ -97,14 +97,6 @@ struct CFGNormalizationPass
builder->setInsertInto(afterBlock);
unreachInst->removeAndDeallocate();
- /*
- HashSet<IRBlock*> predecessorSet;
- for (auto predecessor : parentAfterBlock->getPredecessors())
- predecessorSet.Add(predecessor);
-
- SLANG_ASSERT(predecessorSet.Count() <= 1);
- */
-
builder->emitBranch(parentAfterBlock);
}
}
@@ -169,6 +161,45 @@ struct CFGNormalizationPass
IRBlock* parentAfterBlock = afterBlocks[0];
+ auto addBreakBypassBranch = [&](IRBlock* block)
+ {
+ // We could arrive at the after-block before or
+ // after encountering a break statement.
+ // To handle this, we'll split the flow by checking the break flag
+ //
+ builder.setInsertAfter(block);
+
+ auto preAfterSplitBlock = builder.emitBlock();
+ preAfterSplitBlock->insertBefore(block);
+
+ auto afterSplitBlock = builder.emitBlock();
+ afterSplitBlock->insertBefore(block);
+
+ block->replaceUsesWith(preAfterSplitBlock);
+
+ builder.setInsertInto(preAfterSplitBlock);
+ builder.emitBranch(afterSplitBlock);
+
+ // Converging block for the split that we're making.
+ auto afterSplitAfterBlock = builder.emitBlock();
+
+ builder.setInsertInto(afterSplitBlock);
+ auto breakFlagValue = builder.emitLoad(parentRegion->breakVar);
+
+ builder.emitIfElse(
+ breakFlagValue,
+ block,
+ afterSplitAfterBlock,
+ afterSplitAfterBlock);
+
+ // At this point, we need to place afterSplitAfterBlock between
+ // at the _end_ of this region, but we aren't there yet (and
+ // don't know which block is the end of this region)
+ // Therefore, we'll defer this step and add it to a list for later.
+ //
+ pendingAfterBlocks.add(afterSplitAfterBlock);
+ };
+
// Follow this thread of execution till we hit an
// acceptable after block.
//
@@ -210,12 +241,15 @@ struct CFGNormalizationPass
auto afterBlock = ifElse->getAfterBlock();
// Trivial case, both end-points branch into the after block
- if (trueTargetBlock == afterBlock &&
+ /*if (trueTargetBlock == afterBlock &&
falseTargetBlock == afterBlock)
{
+ if ()
+ addBreakBypassBranch(afterBlock);
currentBlock = afterBlock;
+ // TODO: Need to split block.
break;
- }
+ }*/
auto afterBreakRegion = false;
auto afterBaseRegion = false;
@@ -281,41 +315,7 @@ struct CFGNormalizationPass
// Do we need to split the after region?
if (afterBaseRegion && afterBreakRegion)
{
- // We could arrive at the after-block before or
- // after encountering a break statement.
- // To handle this, we'll split the flow by checking the break flag
- //
- builder.setInsertAfter(afterBlock);
-
- auto preAfterSplitBlock = builder.emitBlock();
- preAfterSplitBlock->insertBefore(afterBlock);
-
- auto afterSplitBlock = builder.emitBlock();
- afterSplitBlock->insertBefore(afterBlock);
-
- afterBlock->replaceUsesWith(preAfterSplitBlock);
-
- builder.setInsertInto(preAfterSplitBlock);
- builder.emitBranch(afterSplitBlock);
-
- // Converging block for the split that we're making.
- auto afterSplitAfterBlock = builder.emitBlock();
-
- builder.setInsertInto(afterSplitBlock);
- auto breakFlagValue = builder.emitLoad(parentRegion->breakVar);
-
- builder.emitIfElse(
- breakFlagValue,
- afterBlock,
- afterSplitAfterBlock,
- afterSplitAfterBlock);
-
- // At this point, we need to place afterSplitAfterBlock between
- // at the _end_ of this region, but we aren't there yet (and
- // don't know which block is the end of this region)
- // Therefore, we'll defer this step and add it to a list for later.
- //
- pendingAfterBlocks.add(afterSplitAfterBlock);
+ addBreakBypassBranch(afterBlock);
// Update current block.
currentBlock = afterBlock;
@@ -419,12 +419,6 @@ struct CFGNormalizationPass
if (isLoopTrivial(as<IRLoop>(branchInst)))
{
auto firstLoopBlock = as<IRLoop>(branchInst)->getTargetBlock();
- auto terminator = firstLoopBlock->getTerminator();
-
- // We really shouldn't see a conditional branch on a trivial loop
- // but if we hit this assert, handle this case.
- //
- SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>(terminator));
// Normalize the region from the first loop block till break.
auto preBreakEndPoint = getNormalizedRegionEndpoint(
@@ -583,6 +577,7 @@ struct CFGNormalizationPass
};
void normalizeCFG(
+ SharedIRBuilder* sharedBuilder,
IRGlobalValueWithCode* func,
IRCFGNormalizationPass const& options)
{
@@ -591,9 +586,7 @@ void normalizeCFG(
//
eliminatePhisInFunc(LivenessMode::Disabled, func->getModule(), func);
- SharedIRBuilder sharedBuilder(func->getModule());
- sharedBuilder.deduplicateAndRebuildGlobalNumberingMap();
- CFGNormalizationContext context = {&sharedBuilder, options.sink};
+ CFGNormalizationContext context = {sharedBuilder, options.sink};
CFGNormalizationPass cfgPass(context);
List<IRBlock*> workList;
@@ -622,7 +615,7 @@ void normalizeCFG(
}
disableIRValidationAtInsert();
- constructSSA(&sharedBuilder, func);
+ constructSSA(sharedBuilder, func);
enableIRValidationAtInsert();
}
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.h b/source/slang/slang-ir-autodiff-cfg-norm.h
index 2a39f7695..f256d8ce8 100644
--- a/source/slang/slang-ir-autodiff-cfg-norm.h
+++ b/source/slang/slang-ir-autodiff-cfg-norm.h
@@ -19,6 +19,7 @@ namespace Slang
/// "after" block.
///
void normalizeCFG(
+ SharedIRBuilder* sharedBuilder,
IRGlobalValueWithCode* func,
IRCFGNormalizationPass const& options = IRCFGNormalizationPass());
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index abe3f718c..f60412efb 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -527,6 +527,9 @@ InstPair ForwardDiffTranscriber::transcribeControlFlow(IRBuilder* builder, IRIns
auto diffArg = lookupDiffInst(origArg, nullptr);
if (diffArg)
newArgs.add(diffArg);
+ else
+ newArgs.add(
+ getDifferentialZeroOfType(builder, origArg->getDataType()));
}
}
@@ -576,16 +579,15 @@ InstPair ForwardDiffTranscriber::transcribeControlFlow(IRBuilder* builder, IRIns
return InstPair(nullptr, nullptr);
}
-InstPair ForwardDiffTranscriber::transcribeConst(IRBuilder* builder, IRInst* origInst)
+InstPair ForwardDiffTranscriber::transcribeConst(IRBuilder*, IRInst* origInst)
{
switch(origInst->getOp())
{
case kIROp_FloatLit:
- return InstPair(origInst, builder->getFloatValue(origInst->getDataType(), 0.0f));
+ case kIROp_IntLit:
+ return InstPair(origInst, nullptr);
case kIROp_VoidLit:
return InstPair(origInst, origInst);
- case kIROp_IntLit:
- return InstPair(origInst, builder->getIntValue(origInst->getDataType(), 0));
}
getSink()->diagnose(
@@ -943,9 +945,15 @@ InstPair ForwardDiffTranscriber::transcribeMakeDifferentialPair(IRBuilder* build
SLANG_ASSERT(primalVal);
auto diffPrimalVal = findOrTranscribePrimalInst(builder, origInst->getDifferentialValue());
SLANG_ASSERT(diffPrimalVal);
+
auto primalDiffVal = findOrTranscribeDiffInst(builder, origInst->getPrimalValue());
+ if (!primalDiffVal)
+ primalDiffVal = getDifferentialZeroOfType(builder, origInst->getPrimalValue()->getDataType());
SLANG_ASSERT(primalDiffVal);
+
auto diffDiffVal = findOrTranscribeDiffInst(builder, origInst->getDifferentialValue());
+ if (!diffDiffVal)
+ diffDiffVal = getDifferentialZeroOfType(builder, origInst->getDifferentialValue()->getDataType());
SLANG_ASSERT(diffDiffVal);
auto primalPairType = findOrTranscribePrimalInst(builder, origInst->getFullType());
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 8f218293d..0f2ceceb4 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -510,7 +510,7 @@ namespace Slang
eliminateMultiLevelBreakForFunc(func->getModule(), func);
IRCFGNormalizationPass cfgPass = {this->getSink()};
- normalizeCFG(func);
+ normalizeCFG(autoDiffSharedContext->sharedBuilder, func);
AutoDiffAddressConversionPolicy cvtPolicty;
cvtPolicty.diffTypeContext = &diffTypeContext;
@@ -544,6 +544,8 @@ namespace Slang
// reversible.
if (SLANG_FAILED(prepareFuncForBackwardDiff(primalFunc)))
return diffPropagateFunc;
+
+ autoDiffSharedContext->sharedBuilder->deduplicateAndRebuildGlobalNumberingMap();
// Forward transcribe the clone of the original func.
ForwardDiffTranscriber& fwdTranscriber = *static_cast<ForwardDiffTranscriber*>(
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 91374e006..520c6d276 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -911,12 +911,14 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst
// Tag the differential inst using a decoration (if it doesn't have one)
if (!pair.differential->findDecoration<IRDifferentialInstDecoration>() &&
- !pair.differential->findDecoration<IRMixedDifferentialInstDecoration>())
+ !pair.differential->findDecoration<IRMixedDifferentialInstDecoration>() &&
+ !as<IRConstant>(pair.differential))
{
// TODO: If the type is a 'relevant' pair type, need to mark it as mixed differential
// instead.
//
- builder->markInstAsDifferential(pair.differential, as<IRType>(pair.primal->getDataType()));
+ auto primalType = as<IRType>(pair.primal->getDataType());
+ builder->markInstAsDifferential(pair.differential, primalType);
}
break;
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index f87aa7751..5aad6e3a3 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -1694,9 +1694,6 @@ struct DiffTransposePass
IRInst* newInst = builder->emitMakeVector(targetType, operands.getCount(), operands.getBuffer());
- if (isDifferentialInst(inst))
- builder->markInstAsDifferential(newInst);
-
return newInst;
}
@@ -1725,6 +1722,11 @@ struct DiffTransposePass
builder->setInsertAfter(operand);
IRInst* newOperand = promoteToType(builder, targetType, operand);
+
+ if (isDifferentialInst(operand))
+ builder->markInstAsDifferential(
+ newOperand, tryGetPrimalTypeFromDiffInst(fwdInst));
+
newOperands.add(newOperand);
needNewInst = true;
@@ -1747,7 +1749,8 @@ struct DiffTransposePass
builder->setInsertLoc(oldLoc);
if (isDifferentialInst(fwdInst))
- builder->markInstAsDifferential(newInst);
+ builder->markInstAsDifferential(
+ newInst, tryGetPrimalTypeFromDiffInst(fwdInst));
return newInst;
}
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 44cb2aa09..daf6e44d4 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -322,7 +322,10 @@ struct ExtractPrimalFuncContext
{
if (shouldStoreInst(inst))
{
- builder.setInsertAfter(inst);
+ if (as<IRParam>(inst))
+ builder.setInsertBefore(block->getFirstOrdinaryInst());
+ else
+ builder.setInsertAfter(inst);
storeInst(builder, inst, outIntermediary);
}
else if (inst->getOp() == kIROp_Var)