diff options
| author | Yong He <yonghe@outlook.com> | 2023-02-13 11:05:29 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-13 11:05:29 -0800 |
| commit | ec49215d711fff9356663390a31182e811e27467 (patch) | |
| tree | c97322a75faf55cd7dfc3b729a4f951571c70bb7 /source | |
| parent | 977eb925b7e9cb1a763c1e5563b2bc605b6476d6 (diff) | |
Various auto-diff bug fixes. (#2646)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 32 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-peephole.cpp | 6 |
5 files changed, 45 insertions, 10 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 7782bd39c..b5d3dba10 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1514,7 +1514,8 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(IRBuilder* builder, IRParam auto diffType = differentiateType(builder, cast<IRPtrTypeBase>(origParam->getDataType())->getValueType()); auto diff = builder->emitVar(diffType); - builder->markInstAsDifferential(diff, ptrInnerPairType->getValueType()); + builder->markInstAsDifferential( + diff, builder->getPtrType(ptrInnerPairType->getValueType())); IRInst* primalInitVal = nullptr; IRInst* diffInitVal = nullptr; diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 20090ca42..ff8ece76c 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -539,6 +539,10 @@ namespace Slang { builder.emitStore(tempVar, builder.emitLoad(param)); } + else + { + builder.emitStore(tempVar, builder.emitDefaultConstruct(ptrType->getValueType())); + } } for (auto block : func->getBlocks()) @@ -589,6 +593,7 @@ namespace Slang AutoDiffAddressConversionPolicy cvtPolicty; cvtPolicty.diffTypeContext = &diffTypeContext; auto result = eliminateAddressInsts(sharedBuilder, &cvtPolicty, func, sink); + if (SLANG_SUCCEEDED(result)) { simplifyFunc(func); @@ -824,6 +829,7 @@ namespace Slang moveInstChildren(existingPrimalHeader, primalFuncGeneric); primalFuncGeneric->replaceUsesWith(existingPrimalHeader); primalFuncGeneric->removeAndDeallocate(); + primalFuncGeneric = existingPrimalHeader; } else { @@ -831,7 +837,7 @@ namespace Slang builder->addBackwardDerivativePrimalDecoration(primalFunc, specializedBackwardPrimalFunc); } - initializeLocalVariables(builder->getSharedBuilder(), primalFunc); + initializeLocalVariables(builder->getSharedBuilder(), as<IRGlobalValueWithCode>(getGenericReturnVal(primalFuncGeneric))); initializeLocalVariables(builder->getSharedBuilder(), diffPropagateFunc); } @@ -957,7 +963,6 @@ namespace Slang // after transposition. auto tempVar = nextBlockBuilder.emitVar(diffType); copyNameHintDecoration(tempVar, fwdParam); - nextBlockBuilder.markInstAsDifferential(tempVar, diffPairType); // Initialize the var with input diff param at start. // Note that we insert the store in the primal block so it won't get transposed. @@ -1088,7 +1093,6 @@ namespace Slang auto diffVar = nextBlockBuilder.emitVar(diffType); copyNameHintDecoration(diffVar, fwdParam); result.propagateFuncSpecificPrimalInsts.add(diffVar); - diffBuilder.markInstAsDifferential(diffVar, diffPairType); diffRefReplacement = diffVar; // Clear the diff read var to zero at start of the function. diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 2953c6206..4e1532153 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -506,6 +506,7 @@ struct DiffTransposePass List<IRBlock*> traverseWorkList; HashSet<IRBlock*> traverseSet; traverseWorkList.add(revDiffFunc->getFirstBlock()); + traverseSet.Add(revDiffFunc->getFirstBlock()); for (IRBlock* block = revDiffFunc->getFirstBlock(); block; block = block->getNextBlock()) { @@ -517,9 +518,13 @@ struct DiffTransposePass // or entirely with differential insts. continue; } + workList.add(block); } + if (!workList.getCount()) + return; + // Reverse the order of the blocks. workList.reverse(); @@ -533,7 +538,32 @@ struct DiffTransposePass // Keep track of first diff block, since this is where // we'll emit temporary vars to hold per-block derivatives. // - firstRevDiffBlockMap[revDiffFunc] = revBlockMap[terminalDiffBlocks[0]]; + auto firstRevDiffBlock = revBlockMap[terminalDiffBlocks[0]].GetValue(); + firstRevDiffBlockMap[revDiffFunc] = firstRevDiffBlock; + + // Move all diff vars to first block, and initialize them with zero. + builder.setInsertInto(firstRevDiffBlock); + for (auto block : workList) + { + for (auto inst = block->getFirstInst(); inst;) + { + auto nextInst = inst->getNextInst(); + if (auto varInst = as<IRVar>(inst)) + { + if (auto diffDecor = varInst->findDecoration<IRDifferentialInstDecoration>()) + { + if (auto ptrPrimalType = as<IRPtrTypeBase>(diffDecor->getPrimalType())) + { + varInst->insertAtEnd(firstRevDiffBlock); + + auto dzero = emitDZeroOfDiffInstType(&builder, ptrPrimalType->getValueType()); + builder.emitStore(varInst, dzero); + } + } + } + inst = nextInst; + } + } for (auto block : workList) { diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 1a85ea6a4..d83ff57e4 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -839,7 +839,7 @@ struct DiffUnzipPass auto diffType = (IRType*) diffTypeContext.getDifferentialForType(primalBuilder, primalType); auto primalVar = primalBuilder->emitVar(primalType); auto diffVar = diffBuilder->emitVar(diffType); - diffBuilder->markInstAsDifferential(diffVar, primalType); + diffBuilder->markInstAsDifferential(diffVar, diffBuilder->getPtrType(primalType)); return InstPair(primalVar, diffVar); } @@ -874,7 +874,7 @@ struct DiffUnzipPass // If return value is not differentiable, just turn it into a trivial branch. auto primalBranch = primalBuilder->emitBranch(firstDiffBlock); primalBuilder->addBackwardDerivativePrimalReturnDecoration( - primalBranch, primalBuilder->getVoidValue()); + primalBranch, mixedReturn->getVal()); auto returnInst = diffBuilder->emitReturn(); diffBuilder->markInstAsDifferential(returnInst, nullptr); diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index 4dbe6d2cb..87c31ffb7 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -427,7 +427,7 @@ struct PeepholeContext : InstPassBase else break; if (i == (IRIntegerValue)constIndex->getValue()) - arg = inst->getOperand(2); + arg = updateInst->getElementValue(); args.add(arg); } if (args.getCount() == arraySize->getValue()) @@ -456,8 +456,8 @@ struct PeepholeContext : InstPassBase IRInst* arg = nullptr; if (i < oldVal->getOperandCount()) arg = oldVal->getOperand(i); - if (field->getKey() == inst->getOperand(1)) - arg = inst->getOperand(2); + if (field->getKey() == key) + arg = updateInst->getElementValue(); if (arg) { args.add(arg); |
