summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-02-22 22:22:26 -0500
committerGitHub <noreply@github.com>2023-02-22 19:22:26 -0800
commite8c08e7ecb1124f115a1d1042277776193122b57 (patch)
tree9c1d970c8be244aa4a32762e1de3338507d24444 /source
parent6eb0b4dea4da1fc21767c86cc0837d0c8b68063b (diff)
Fixed hoisting of intermediate array & context vars (#2674)
Also added legalization for loops
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-cfg-norm.cpp1
-rw-r--r--source/slang/slang-ir-autodiff-rev.h2
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h8
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp113
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h143
5 files changed, 201 insertions, 66 deletions
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp
index d64c6d1f6..9116f67e9 100644
--- a/source/slang/slang-ir-autodiff-cfg-norm.cpp
+++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp
@@ -409,6 +409,7 @@ struct CFGNormalizationPass
// false -> atleast one break statement hit.
//
info.breakVar = builder.emitVar(builder.getBoolType());
+ builder.addNameHintDecoration(info.breakVar, UnownedStringSlice("_bflag"));
builder.emitStore(info.breakVar, builder.getBoolValue(true));
// If the loop is trivial (i.e. single iteration, with no
diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h
index 94bc1ef81..86a6f2846 100644
--- a/source/slang/slang-ir-autodiff-rev.h
+++ b/source/slang/slang-ir-autodiff-rev.h
@@ -32,8 +32,8 @@ struct ParameterBlockTransposeInfo
// The value with which a primal specific parameter should be replaced in propagate func.
OrderedDictionary<IRInst*, IRInst*> mapPrimalSpecificParamToReplacementInPropFunc;
-
// The insts added that is specific for propagate functions and should be removed
+
// from the future primal func.
List<IRInst*> propagateFuncSpecificPrimalInsts;
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 95ad58586..a4c79d09a 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -1138,7 +1138,11 @@ struct DiffTransposePass
IRInst* hoistPrimalInst(IRBuilder* revBuilder, IRInst* inst)
{
- SLANG_RELEASE_ASSERT(isPrimalInst(inst));
+ if (as<IRBlock>(inst->getParent()) &&
+ isDifferentialInst(as<IRBlock>(inst->getParent())))
+ {
+ SLANG_RELEASE_ASSERT(isPrimalInst(inst));
+ }
// Are the operands of this primal inst also available in the reverse-mode context?
// If not, move/load them.
@@ -1379,7 +1383,7 @@ struct DiffTransposePass
// In order to perform the call, we need a temporary var to store the DiffPair.
auto pairType = as<IRPtrTypeBase>(arg->getDataType())->getValueType();
auto tempVar = builder->emitVar(pairType);
- auto primalVal = builder->emitLoad(instPair->getPrimal());
+ auto primalVal = builder->emitLoad(hoistPrimalInst(builder, instPair->getPrimal()));
auto diffVal = builder->emitLoad(instPair->getDiff());
auto pairVal = builder->emitMakeDifferentialPair(pairType, primalVal, diffVal);
builder->emitStore(tempVar, pairVal);
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 50c5c4ea6..096751836 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -130,6 +130,91 @@ struct ExtractPrimalFuncContext
}
}
+ bool doesInstHaveDiffUse(IRInst* inst)
+ {
+ bool hasDiffUser = false;
+
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ auto user = use->getUser();
+ if (isDiffInst(user))
+ {
+ // Ignore uses that is a return or MakeDiffPair
+ switch (user->getOp())
+ {
+ case kIROp_Return:
+ continue;
+ case kIROp_MakeDifferentialPair:
+ if (!user->hasMoreThanOneUse() && user->firstUse &&
+ user->firstUse->getUser()->getOp() == kIROp_Return)
+ continue;
+ break;
+ default:
+ break;
+ }
+ hasDiffUser = true;
+ break;
+ }
+ }
+
+ return hasDiffUser;
+ }
+
+ bool doesInstHaveStore(IRInst* inst)
+ {
+ SLANG_RELEASE_ASSERT(as<IRPtrTypeBase>(inst->getDataType()));
+
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ if (as<IRStore>(use->getUser()))
+ return true;
+
+ if (as<IRPtrTypeBase>(use->getUser()->getDataType()))
+ {
+ if (doesInstHaveStore(use->getUser()))
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ bool isIntermediateContextType(IRType* type)
+ {
+ switch (type->getOp())
+ {
+ case kIROp_BackwardDiffIntermediateContextType:
+ return true;
+ case kIROp_PtrType:
+ return isIntermediateContextType(as<IRPtrTypeBase>(type)->getValueType());
+ case kIROp_ArrayType:
+ return isIntermediateContextType(as<IRArrayType>(type)->getElementType());
+ }
+
+ return false;
+ }
+
+ bool shouldStoreVar(IRVar* var)
+ {
+ // Always store intermediate context var.
+ if (var->findDecoration<IRBackwardDerivativePrimalContextDecoration>())
+ {
+ return true;
+ }
+
+ if (isIntermediateContextType(var->getDataType()))
+ {
+ return true;
+ }
+
+ // For now the store policy is simple, we use two conditions:
+ // 1. Is the var used in a differential block and,
+ // 2. Does the var have a store
+ //
+
+ return (doesInstHaveDiffUse(var) && doesInstHaveStore(var));
+ }
+
bool shouldStoreInst(IRInst* inst)
{
if (!inst->getDataType())
@@ -181,29 +266,7 @@ struct ExtractPrimalFuncContext
}
// Only store if the inst has differential inst user.
- bool hasDiffUser = false;
- for (auto use = inst->firstUse; use; use = use->nextUse)
- {
- auto user = use->getUser();
- if (isDiffInst(user))
- {
- // Ignore uses that is a return or MakeDiffPair
- switch (user->getOp())
- {
- case kIROp_Return:
- continue;
- case kIROp_MakeDifferentialPair:
- if (!user->hasMoreThanOneUse() && user->firstUse &&
- user->firstUse->getUser()->getOp() == kIROp_Return)
- continue;
- break;
- default:
- break;
- }
- hasDiffUser = true;
- break;
- }
- }
+ bool hasDiffUser = doesInstHaveDiffUse(inst);
if (!hasDiffUser)
return false;
@@ -303,8 +366,7 @@ struct ExtractPrimalFuncContext
}
else if (inst->getOp() == kIROp_Var)
{
- // Always store intermediate context var.
- if (inst->findDecoration<IRBackwardDerivativePrimalContextDecoration>())
+ if (shouldStoreVar(as<IRVar>(inst)))
{
auto field = addIntermediateContextField(cast<IRPtrTypeBase>(inst->getDataType())->getValueType(), outIntermediary);
builder.setInsertBefore(inst);
@@ -313,6 +375,7 @@ struct ExtractPrimalFuncContext
inst->replaceUsesWith(fieldAddr);
builder.addPrimalValueStructKeyDecoration(inst, field->getKey());
}
+
}
}
}
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index e2c84ce8b..2ebc330f0 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -414,6 +414,7 @@ struct DiffUnzipPass
// Make variable in the top-most block (so it's visible to diff blocks)
region->primalCountLastVar = builder.emitVar(builder.getIntType());
+ builder.addNameHintDecoration(region->primalCountLastVar, UnownedStringSlice("_pc_last_var"));
{
IRBlock* primalInitBlock = as<IRBlock>(primalMap[region->initBlock]);
@@ -432,6 +433,7 @@ struct DiffUnzipPass
primalCondBlock,
builder.getIntType(),
phiCounterArgLoopEntryIndex);
+ builder.addNameHintDecoration(region->primalCountParam, UnownedStringSlice("_pc"));
builder.addLoopCounterDecoration(region->primalCountParam);
builder.markInstAsPrimal(region->primalCountParam);
@@ -471,6 +473,7 @@ struct DiffUnzipPass
diffCondBlock,
builder.getIntType(),
phiCounterArgLoopEntryIndex);
+ builder.addNameHintDecoration(region->diffCountParam, UnownedStringSlice("_dc"));
builder.addLoopCounterDecoration(region->diffCountParam);
builder.markInstAsPrimal(region->diffCountParam);
@@ -535,33 +538,11 @@ struct DiffUnzipPass
as<IRFuncType>(child->getDataType()) ||
as<IRTypeKind>(child->getDataType()))
continue;
-
- // We also don't care about pointer types (only Loads)
- if (auto ptrType = as<IRPtrTypeBase>(child->getDataType()))
- {
- // There's an exception to this, if the var is an intermediate context type
- // variable since there won't be a load from this yet (the load will
- // be inserted later during the transposition process)
- //
- if (as<IRBackwardDiffIntermediateContextType>(ptrType->getValueType()))
- primalInsts.add(child);
-
- continue;
- }
primalInsts.add(child);
}
IRBuilder builder(autodiffContext->moduleInst->getModule());
-
- // Build list of indices that this block is affected by.
- List<IndexedRegion*> regions;
- {
- IndexedRegion* region = indexRegionMap[fwdBlock];
- for (; region; region = region->parent)
- regions.add(region);
- }
-
for (auto inst : primalInsts)
{
@@ -581,43 +562,115 @@ struct DiffUnzipPass
if (!shouldStore) continue;
- // 2. Emit an array to top-level to allocate space.
-
- builder.setInsertBefore(firstPrimalBlock->getTerminator());
+ // 2. If we're dealing with a var, we need to locate the value that
+ // we actually need to store. We assume everything is SSA form
+ // so there must be a single IRStore on this var.
+ //
+ IRInst* valueToStore = nullptr;
+ IRBlock* valueBlock = nullptr;
+ IRType* valueType = nullptr;
- IRType* arrayType = inst->getDataType();
bool isPtrType = false;
+ bool isIntermediateContext = false;
- if (auto ptrType = as<IRPtrTypeBase>(arrayType))
+ if (auto ptrValueType = as<IRPtrTypeBase>(inst->getDataType()))
{
- SLANG_RELEASE_ASSERT(as<IRBackwardDiffIntermediateContextType>(ptrType->getValueType()));
- arrayType = ptrType->getValueType();
isPtrType = true;
+
+ // Find value to store
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ if (auto storeInst = as<IRStore>(use->getUser()))
+ {
+ // Should not see more than one IRStore
+ SLANG_RELEASE_ASSERT(!valueToStore);
+ valueToStore = storeInst->getVal();
+
+ // Is this the right block to use to determine if the
+ // store can have multiple values based on the index?
+ //
+ valueBlock = as<IRBlock>(storeInst->getParent());
+ }
+ }
+
+ if (as<IRBackwardDiffIntermediateContextType>(ptrValueType->getValueType()))
+ {
+ isIntermediateContext = true;
+
+ // TODO: This should be the parent block of the `call` associated
+ // with this context type. The var itself _could_ be in a different place.
+ //
+ valueBlock = as<IRBlock>(inst->getParent());
+ }
+
+ valueType = ptrValueType->getValueType();
+ }
+ else
+ {
+ isPtrType = false;
+ valueToStore = inst;
+ valueBlock = as<IRBlock>(inst->getParent());
+ valueType = inst->getDataType();
}
+ // What do we do for primal vars that are used in the diff block
+ // but do not have an IRStore on them? This can happen for 'out'
+ // primal variables.
+ //
+ if (!valueToStore && !isIntermediateContext)
+ {
+ // For now, we can ignore them since they are used as inputs
+ // to 'out' parameters. If their value is every actually used,
+ // we will see an IRLoad which will be hoisted accordingly.
+ //
+ continue;
+ }
+
+ // Build list of indices that the value's block is affected by.
+ List<IndexedRegion*> regions;
+ {
+ IndexedRegion* region = indexRegionMap[valueBlock];
+ for (; region; region = region->parent)
+ regions.add(region);
+ }
+
+ // 3. Emit an array to top-level to allocate space.
+
+ builder.setInsertBefore(firstPrimalBlock->getTerminator());
+
+ IRType* storageType = valueType;
+
for (auto region : regions)
{
SLANG_ASSERT(region->status == IndexedRegion::CountStatus::Static);
SLANG_ASSERT(region->maxIters >= 0);
- arrayType = builder.getArrayType(
- arrayType,
+ storageType = builder.getArrayType(
+ storageType,
builder.getIntValue(
builder.getUIntType(),
region->maxIters + 1));
}
- // Reverse the list since the indices needs to be
+ // Reverse the list since the indices need to be
// emitted in reverse order.
//
regions.reverse();
- auto storageVar = builder.emitVar(arrayType);
+ auto storageVar = builder.emitVar(storageType);
+ if (isIntermediateContext)
+ builder.addBackwardDerivativePrimalContextDecoration(
+ storageVar,
+ storageVar);
- // 3. Store current value into the array and replace uses with a load.
+ // 4. Store current value into the array and replace uses with a load.
// TODO: If an index is missing, use the 'last' value of the primal index.
+
{
- setInsertAfterOrdinaryInst(&builder, inst);
+ if (!isIntermediateContext)
+ setInsertAfterOrdinaryInst(&builder, valueToStore);
+ else
+ setInsertAfterOrdinaryInst(&builder, inst);
IRInst* storeAddr = storageVar;
IRType* currType = as<IRPtrTypeBase>(storageVar->getDataType())->getValueType();
@@ -631,11 +684,25 @@ struct DiffUnzipPass
storeAddr,
region->primalCountParam);
}
-
- builder.emitStore(storeAddr, inst);
+
+ if (!isIntermediateContext)
+ builder.emitStore(storeAddr, valueToStore);
+ else
+ {
+ List<IRUse*> primalUses;
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ if (!isDifferentialInst(getBlock(use->getUser())))
+ primalUses.add(use);
+ }
+
+ for (auto use : primalUses)
+ use->set(storeAddr);
+ }
}
+
- // 4. Replace uses in differential blocks with loads from the array.
+ // 5. Replace uses in differential blocks with loads from the array.
List<IRInst*> instsToTag;
{
List<IRUse*> diffUses;