summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-02-13 11:05:29 -0800
committerGitHub <noreply@github.com>2023-02-13 11:05:29 -0800
commitec49215d711fff9356663390a31182e811e27467 (patch)
treec97322a75faf55cd7dfc3b729a4f951571c70bb7 /source
parent977eb925b7e9cb1a763c1e5563b2bc605b6476d6 (diff)
Various auto-diff bug fixes. (#2646)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp3
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp10
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h32
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h4
-rw-r--r--source/slang/slang-ir-peephole.cpp6
5 files changed, 45 insertions, 10 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 7782bd39c..b5d3dba10 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -1514,7 +1514,8 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(IRBuilder* builder, IRParam
auto diffType = differentiateType(builder, cast<IRPtrTypeBase>(origParam->getDataType())->getValueType());
auto diff = builder->emitVar(diffType);
- builder->markInstAsDifferential(diff, ptrInnerPairType->getValueType());
+ builder->markInstAsDifferential(
+ diff, builder->getPtrType(ptrInnerPairType->getValueType()));
IRInst* primalInitVal = nullptr;
IRInst* diffInitVal = nullptr;
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 20090ca42..ff8ece76c 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -539,6 +539,10 @@ namespace Slang
{
builder.emitStore(tempVar, builder.emitLoad(param));
}
+ else
+ {
+ builder.emitStore(tempVar, builder.emitDefaultConstruct(ptrType->getValueType()));
+ }
}
for (auto block : func->getBlocks())
@@ -589,6 +593,7 @@ namespace Slang
AutoDiffAddressConversionPolicy cvtPolicty;
cvtPolicty.diffTypeContext = &diffTypeContext;
auto result = eliminateAddressInsts(sharedBuilder, &cvtPolicty, func, sink);
+
if (SLANG_SUCCEEDED(result))
{
simplifyFunc(func);
@@ -824,6 +829,7 @@ namespace Slang
moveInstChildren(existingPrimalHeader, primalFuncGeneric);
primalFuncGeneric->replaceUsesWith(existingPrimalHeader);
primalFuncGeneric->removeAndDeallocate();
+ primalFuncGeneric = existingPrimalHeader;
}
else
{
@@ -831,7 +837,7 @@ namespace Slang
builder->addBackwardDerivativePrimalDecoration(primalFunc, specializedBackwardPrimalFunc);
}
- initializeLocalVariables(builder->getSharedBuilder(), primalFunc);
+ initializeLocalVariables(builder->getSharedBuilder(), as<IRGlobalValueWithCode>(getGenericReturnVal(primalFuncGeneric)));
initializeLocalVariables(builder->getSharedBuilder(), diffPropagateFunc);
}
@@ -957,7 +963,6 @@ namespace Slang
// after transposition.
auto tempVar = nextBlockBuilder.emitVar(diffType);
copyNameHintDecoration(tempVar, fwdParam);
- nextBlockBuilder.markInstAsDifferential(tempVar, diffPairType);
// Initialize the var with input diff param at start.
// Note that we insert the store in the primal block so it won't get transposed.
@@ -1088,7 +1093,6 @@ namespace Slang
auto diffVar = nextBlockBuilder.emitVar(diffType);
copyNameHintDecoration(diffVar, fwdParam);
result.propagateFuncSpecificPrimalInsts.add(diffVar);
- diffBuilder.markInstAsDifferential(diffVar, diffPairType);
diffRefReplacement = diffVar;
// Clear the diff read var to zero at start of the function.
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 2953c6206..4e1532153 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -506,6 +506,7 @@ struct DiffTransposePass
List<IRBlock*> traverseWorkList;
HashSet<IRBlock*> traverseSet;
traverseWorkList.add(revDiffFunc->getFirstBlock());
+
traverseSet.Add(revDiffFunc->getFirstBlock());
for (IRBlock* block = revDiffFunc->getFirstBlock(); block; block = block->getNextBlock())
{
@@ -517,9 +518,13 @@ struct DiffTransposePass
// or entirely with differential insts.
continue;
}
+
workList.add(block);
}
+ if (!workList.getCount())
+ return;
+
// Reverse the order of the blocks.
workList.reverse();
@@ -533,7 +538,32 @@ struct DiffTransposePass
// Keep track of first diff block, since this is where
// we'll emit temporary vars to hold per-block derivatives.
//
- firstRevDiffBlockMap[revDiffFunc] = revBlockMap[terminalDiffBlocks[0]];
+ auto firstRevDiffBlock = revBlockMap[terminalDiffBlocks[0]].GetValue();
+ firstRevDiffBlockMap[revDiffFunc] = firstRevDiffBlock;
+
+ // Move all diff vars to first block, and initialize them with zero.
+ builder.setInsertInto(firstRevDiffBlock);
+ for (auto block : workList)
+ {
+ for (auto inst = block->getFirstInst(); inst;)
+ {
+ auto nextInst = inst->getNextInst();
+ if (auto varInst = as<IRVar>(inst))
+ {
+ if (auto diffDecor = varInst->findDecoration<IRDifferentialInstDecoration>())
+ {
+ if (auto ptrPrimalType = as<IRPtrTypeBase>(diffDecor->getPrimalType()))
+ {
+ varInst->insertAtEnd(firstRevDiffBlock);
+
+ auto dzero = emitDZeroOfDiffInstType(&builder, ptrPrimalType->getValueType());
+ builder.emitStore(varInst, dzero);
+ }
+ }
+ }
+ inst = nextInst;
+ }
+ }
for (auto block : workList)
{
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index 1a85ea6a4..d83ff57e4 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -839,7 +839,7 @@ struct DiffUnzipPass
auto diffType = (IRType*) diffTypeContext.getDifferentialForType(primalBuilder, primalType);
auto primalVar = primalBuilder->emitVar(primalType);
auto diffVar = diffBuilder->emitVar(diffType);
- diffBuilder->markInstAsDifferential(diffVar, primalType);
+ diffBuilder->markInstAsDifferential(diffVar, diffBuilder->getPtrType(primalType));
return InstPair(primalVar, diffVar);
}
@@ -874,7 +874,7 @@ struct DiffUnzipPass
// If return value is not differentiable, just turn it into a trivial branch.
auto primalBranch = primalBuilder->emitBranch(firstDiffBlock);
primalBuilder->addBackwardDerivativePrimalReturnDecoration(
- primalBranch, primalBuilder->getVoidValue());
+ primalBranch, mixedReturn->getVal());
auto returnInst = diffBuilder->emitReturn();
diffBuilder->markInstAsDifferential(returnInst, nullptr);
diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp
index 4dbe6d2cb..87c31ffb7 100644
--- a/source/slang/slang-ir-peephole.cpp
+++ b/source/slang/slang-ir-peephole.cpp
@@ -427,7 +427,7 @@ struct PeepholeContext : InstPassBase
else
break;
if (i == (IRIntegerValue)constIndex->getValue())
- arg = inst->getOperand(2);
+ arg = updateInst->getElementValue();
args.add(arg);
}
if (args.getCount() == arraySize->getValue())
@@ -456,8 +456,8 @@ struct PeepholeContext : InstPassBase
IRInst* arg = nullptr;
if (i < oldVal->getOperandCount())
arg = oldVal->getOperand(i);
- if (field->getKey() == inst->getOperand(1))
- arg = inst->getOperand(2);
+ if (field->getKey() == key)
+ arg = updateInst->getElementValue();
if (arg)
{
args.add(arg);