diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-01-04 23:40:13 +0530 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-04 10:10:13 -0800 |
| commit | 7f64b2a9e3eb7aea13de550bd24c1aea7787c94b (patch) | |
| tree | 40afc50c9fb227b8728487403d3f9b712a1509b2 /source | |
| parent | e8f977a00f5d131ec2d51d2a026d6452e8f762f0 (diff) | |
Multi-block reverse-mode autodiff (#2576)
* Initial multi-block implementation
* Implemented multi-block reverse-mode (without loops)
* Added logic to remove block-level decorations to avoid confusing IR simplification passes
* Fixed issues with block-level decorations during IR simplification by removing them prior to simplification.
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 23 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 12 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 631 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.cpp | 95 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 181 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 51 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 4 |
9 files changed, 847 insertions, 159 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index dbf79b5f8..c245701df 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -519,6 +519,11 @@ InstPair ForwardDiffTranscriber::transcribeControlFlow(IRBuilder* builder, IRIns // block to compute *both* primals and derivatives (i.e linearized block) SLANG_ASSERT(diffBranch); + // Since blocks always compute both primals and differentials, the branch + // instructions are also always mixed. + // + builder->markInstAsMixedDifferential(diffBranch); + return InstPair(diffBranch, diffBranch); } @@ -740,6 +745,7 @@ InstPair ForwardDiffTranscriber::transcribeLoop(IRBuilder* builder, IRLoop* orig kIROp_loop, diffLoopOperands.getCount(), diffLoopOperands.getBuffer()); + builder->markInstAsMixedDifferential(diffLoop); return InstPair(diffLoop, diffLoop); } @@ -779,13 +785,14 @@ InstPair ForwardDiffTranscriber::transcribeIfElse(IRBuilder* builder, IRIfElse* diffIfElseArgs.add(primalOperand); } - IRInst* diffLoop = builder->emitIntrinsicInst( + IRInst* diffIfElse = builder->emitIntrinsicInst( nullptr, kIROp_ifElse, diffIfElseArgs.getCount(), diffIfElseArgs.getBuffer()); + builder->markInstAsMixedDifferential(diffIfElse); - return InstPair(diffLoop, diffLoop); + return InstPair(diffIfElse, diffIfElse); } InstPair ForwardDiffTranscriber::transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst) @@ -963,10 +970,16 @@ InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* pr builder.setInsertInto(diffFunc); differentiableTypeConformanceContext.setFunc(primalFunc); + // Transcribe children from origFunc into diffFunc for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock()) this->transcribe(&builder, block); + // Some of the transcribed blocks can appear 'out-of-order'. Although this + // shouldn't be an issue, for consistency, we put them back in order. + for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock()) + as<IRBlock>(lookupDiffInst(block))->insertAtEnd(diffFunc); + return InstPair(primalFunc, diffFunc); } @@ -1124,6 +1137,12 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* return trascribeNonDiffInst(builder, origInst); case kIROp_StructKey: return InstPair(origInst, nullptr); + case kIROp_Unreachable: + { + auto unreachInst = builder->emitUnreachable(); + builder->markInstAsMixedDifferential(unreachInst); + return InstPair(unreachInst, nullptr); + } case kIROp_MakeExistentialWithRTTI: SLANG_UNEXPECTED("MakeExistentialWithRTTI inst is not expected in autodiff pass."); diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index cfee49eb1..ae9b69f61 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -429,11 +429,6 @@ namespace Slang block->insertAtEnd(diffFunc); } - // Extracts the primal computations into its own func, and replace the primal insts - // with the intermediate results computed from the extracted func. - IRInst* intermediateType = nullptr; - auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(diffFunc, unzippedFwdDiffFunc, intermediateType); - // Transpose the first block (parameter block) transposeParameterBlock(builder, diffFunc); @@ -445,7 +440,12 @@ namespace Slang DiffTransposePass::FuncTranspositionInfo info = {dOutParameter, nullptr}; diffTransposePass->transposeDiffBlocksInFunc(diffFunc, info); - // Clean up by deallocating intermediate steps. + // Extracts the primal computations into its own func, and replace the primal insts + // with the intermediate results computed from the extracted func. + IRInst* intermediateType = nullptr; + auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(diffFunc, unzippedFwdDiffFunc, intermediateType); + + // Clean up by deallocating intermediate versions. tempDiffFunc->removeAndDeallocate(); unzippedFwdDiffFunc->removeAndDeallocate(); fwdDiffFunc->removeAndDeallocate(); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index da7762908..69cef941c 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -598,8 +598,9 @@ InstPair AutoDiffTranscriberBase::transcribeBlock(IRBuilder* builder, IRBlock* o { IRBuilder subBuilder(builder->getSharedBuilder()); subBuilder.setInsertLoc(builder->getInsertLoc()); - + IRInst* diffBlock = subBuilder.emitBlock(); + subBuilder.markInstAsMixedDifferential(diffBlock); // Note: for blocks, we setup the mapping _before_ // processing the children since we could encounter diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index a14ecad84..436a17a7f 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -54,6 +54,24 @@ struct DiffTransposePass Flavor flavor; }; + struct BlockPredecessorEntry + { + // Previous block. + IRBlock* prevBlock; + + // Integer value corresponding to this predecessor. + IRIntegerValue* indexVal; + }; + + struct ControlFlowTranspositionInfo + { + // Variable used for recording control flow. + IRVar* controlVar; + + // Info about all possible predecessor blocks. + Dictionary<IRBlock*, IRInst*> predEntries; + }; + DiffTransposePass(AutoDiffSharedContext* autodiffContext) : autodiffContext(autodiffContext), pairBuilder(autodiffContext), diffTypeContext(autodiffContext) { } @@ -90,6 +108,11 @@ struct DiffTransposePass { // Grab all differentiable type information. diffTypeContext.setFunc(revDiffFunc); + + // Note down terminal primal and terminal differential blocks + // since we need to link them up at the end. + auto terminalPrimalBlocks = getTerminalPrimalBlocks(revDiffFunc); + auto terminalDiffBlocks = getTerminalDiffBlocks(revDiffFunc); // Traverse all instructions/blocks in reverse (starting from the terminator inst) // look for insts/blocks marked with IRDifferentialInstDecoration, @@ -117,9 +140,20 @@ struct DiffTransposePass workList.add(block); } - // TODO: We *might* need a step here that 'sorts' the work list in reverse order starting with 'leaf' - // differential blocks, and following the branches backwards. - // The alternative is to make phi nodes and treat all intermediaries & their gradients as arguments. + // Reverse the order of the blocks. + workList.reverse(); + + // Emit empty rev-mode blocks for every fwd-mode block. + for (auto block : workList) + { + revBlockMap[block] = builder.emitBlock(); + builder.markInstAsDifferential(revBlockMap[block]); + } + + // Keep track of first diff block, since this is where + // we'll emit temporary vars to hold per-block derivatives. + // + firstRevDiffBlockMap[revDiffFunc] = revBlockMap[workList[0]]; for (auto block : workList) { @@ -129,27 +163,123 @@ struct DiffTransposePass this->addRevGradientForFwdInst(returnInst, RevGradient(returnInst, transposeInfo.dOutInst, nullptr)); } - IRBlock* revBlock = builder.emitBlock(); + IRBlock* revBlock = revBlockMap[block]; this->transposeBlock(block, revBlock); + } + + // Link the last differential fwd-mode block (which will be the first + // rev-mode block) as the successor to the last primal block. + // We assume that the original function is in single-return form + // So, there should be exactly 1 'last' block of each type. + // + { + SLANG_ASSERT(terminalPrimalBlocks.getCount() == 1); + SLANG_ASSERT(terminalDiffBlocks.getCount() == 1); + + auto terminalPrimalBlock = terminalPrimalBlocks[0]; + auto terminalRevBlock = as<IRBlock>(revBlockMap[terminalDiffBlocks[0]]); + + terminalPrimalBlock->getTerminator()->removeAndDeallocate(); + + IRBuilder subBuilder(builder.getSharedBuilder()); + subBuilder.setInsertInto(terminalPrimalBlock); + + // There should be no parameters in the first reverse-mode block. + SLANG_ASSERT(terminalRevBlock->getFirstParam() == nullptr); - // TODO: This should only really be used for the transition from - // the 'last' primal block(s) to the first differential block. - // Transitions from differential blocks to - block->replaceUsesWith(revBlock); + subBuilder.emitBranch(terminalRevBlock); + } + + // Remove fwd-mode blocks. + for (auto block : workList) + { block->removeAndDeallocate(); } } - // A[cond_inst] -> (B or C) -> D => D[cond_inst] -> (B_T -> C_T) -> A_T + // Fetch or create a gradient accumulator var + // corresponding to a inst. These are used to + // accumulate gradients across blocks. + // + IRVar* getOrCreateAccumulatorVar(IRInst* fwdInst) + { + // Check if we have a var already. + if (revAccumulatorVarMap.ContainsKey(fwdInst)) + return revAccumulatorVarMap[fwdInst]; + + IRBuilder tempVarBuilder(autodiffContext->sharedBuilder); + + IRBlock* firstDiffBlock = firstRevDiffBlockMap[as<IRFunc>(fwdInst->getParent()->getParent())]; + tempVarBuilder.setInsertBefore(firstDiffBlock->getTerminator()); + + auto primalType = tryGetPrimalTypeFromDiffInst(fwdInst); + auto diffType = fwdInst->getDataType(); + + auto zeroMethod = diffTypeContext.getZeroMethodForType( + &tempVarBuilder, + primalType); + + SLANG_ASSERT(zeroMethod); + + // Emit a var in the top-level differential block to hold the gradient, + // and initialize it. + auto tempRevVar = tempVarBuilder.emitVar(diffType); + auto diffZero = tempVarBuilder.emitCallInst( + diffType, + zeroMethod, + List<IRInst*>()); + tempVarBuilder.emitStore(tempRevVar, diffZero); + + revAccumulatorVarMap[fwdInst] = tempRevVar; + + return tempRevVar; + } + + bool isInstUsedOutsideParentBlock(IRInst* inst) + { + auto currBlock = inst->getParent(); + + for (auto use = inst->firstUse; use; use = use->nextUse) + { + if (use->getUser()->getParent() != currBlock) + return true; + } + return false; + } + void transposeBlock(IRBlock* fwdBlock, IRBlock* revBlock) { IRBuilder builder; builder.init(autodiffContext->sharedBuilder); - // Insert after the last block. + // Insert into our reverse block. builder.setInsertInto(revBlock); + // Check if this block has any 'outputs' (in the form of phi args + // sent to the successor bvock) + // + if (auto branchInst = as<IRUnconditionalBranch>(fwdBlock->getTerminator())) + { + for (UIndex ii = 0; ii < branchInst->getArgCount(); ii++) + { + auto arg = branchInst->getArg(ii); + if (isDifferentialInst(arg)) + { + auto diffType = arg->getDataType(); + auto revParam = builder.emitParam(diffType); + + addRevGradientForFwdInst( + arg, + RevGradient( + RevGradient::Flavor::Simple, + arg, + revParam, + nullptr)); + } + } + } + // Move pointer & reference insts to the top of the reverse-mode block. List<IRInst*> nonValueInsts; for (IRInst* child = fwdBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) @@ -178,7 +308,7 @@ struct DiffTransposePass // for (IRInst* child = fwdBlock->getLastChild(); child; child = child->getPrevInst()) { - if (as<IRDecoration>(child)) + if (as<IRDecoration>(child) || as<IRParam>(child)) continue; transposeInst(&builder, child); @@ -193,22 +323,78 @@ struct DiffTransposePass // for (auto pair : gradientsMap) { - if (auto param = as<IRLoad>(pair.Key)) - accumulateGradientsForLoad(&builder, param); + if (auto loadInst = as<IRLoad>(pair.Key)) + accumulateGradientsForLoad(&builder, loadInst); } - // Emit a terminator inst. - // TODO: need a be a lot smarter here. For now, we assume a single differential - // block, so it should end in a return statement. - if (as<IRReturn>(fwdBlock->getTerminator())) + // Do the same thing with the phi parameters if the block. + List<IRInst*> phiParamRevGradInsts; + for (IRParam* param = fwdBlock->getFirstParam(); param; param = param->getNextParam()) { - // Emit a void return. - builder.emitReturn(); + if (hasRevGradients(param)) + { + auto gradients = popRevGradients(param); + + auto gradInst = emitAggregateValue( + &builder, + tryGetPrimalTypeFromDiffInst(param), + gradients); + + phiParamRevGradInsts.add(gradInst); + } } - else + + // Also handle any remaining gradients for insts that appear in prior blocks. + List<IRInst*> externInsts; // Holds insts in a different block, same function. + List<IRInst*> globalInsts; // Holds insts in the global scope. + for (auto pair : gradientsMap) { - SLANG_UNEXPECTED("Unhandled block terminator"); + auto instParent = pair.Key->getParent(); + if (instParent != fwdBlock) + { + if (instParent->getParent() == fwdBlock->getParent()) + externInsts.add(pair.Key); + + if (as<IRModuleInst>(instParent)) + globalInsts.add(pair.Key); + } } + + for (auto externInst : externInsts) + { + auto primalType = tryGetPrimalTypeFromDiffInst(externInst); + SLANG_ASSERT(primalType); + + if (auto accVar = getOrCreateAccumulatorVar(externInst)) + { + // Accumulate all gradients, including our accumulator variable, + // into one inst. + // + auto gradients = popRevGradients(externInst); + gradients.add(RevGradient(externInst, builder.emitLoad(accVar), nullptr)); + + auto gradInst = emitAggregateValue( + &builder, + primalType, + gradients); + + builder.emitStore(accVar, gradInst); + } + } + + // For now, we're not going to handle global insts, and simply ignore them + // Eventually, we want to turn these into global writes. + // + for (auto globalInst : globalInsts) + { + if (hasRevGradients(globalInst)) + popRevGradients(globalInst); + } + + // We _should_ be completely out of gradients to process at this point. + SLANG_ASSERT(gradientsMap.Count() == 0); + + emitTerminator(&builder, fwdBlock, phiParamRevGradInsts); } void transposeInst(IRBuilder* builder, IRInst* inst) @@ -242,13 +428,32 @@ struct DiffTransposePass if (!primalType) { // Check for special insts for which a reverse-mode gradient doesn't apply. - if(!as<IRStore>(inst)) + if(!as<IRStore>(inst) && !as<IRTerminatorInst>(inst)) { SLANG_UNEXPECTED("Could not resolve primal type for diff inst"); } + + // If we still can't resolve a differential type, there shouldn't + // be any gradients to aggregate. + // + SLANG_ASSERT(gradients.getCount() == 0); } - // Emit the aggregate of all the gradients here. This will form the total derivative for this inst. + // Is this inst used in another differential block? + // Emit a function-scope accumulator variable, and include it's value. + // Also, we ignore this if it's a load since those are turned into stores + // on a per-block basis. (We should change this behaviour to treat loads like + // any other inst) + // + if (isInstUsedOutsideParentBlock(inst) && !as<IRLoad>(inst)) + { + auto accVar = getOrCreateAccumulatorVar(inst); + gradients.add( + RevGradient(inst, builder->emitLoad(accVar), nullptr)); + } + + // Emit the aggregate of all the gradients here. + // This will form the total derivative for this inst. auto revValue = emitAggregateValue(builder, primalType, gradients); auto transposeResult = transposeInst(builder, inst, revValue); @@ -376,15 +581,297 @@ struct DiffTransposePass return TranspositionResult(gradients); } + + IRBlock* getPrimalBlock(IRBlock* fwdBlock) + { + if (auto fwdDiffDecoration = fwdBlock->findDecoration<IRDifferentialInstDecoration>()) + { + return as<IRBlock>(fwdDiffDecoration->getPrimalInst()); + } + + return nullptr; + } + + IRBlock* getFirstCodeBlock(IRGlobalValueWithCode* func) + { + return func->getFirstBlock()->getNextBlock(); + } + + List<IRBlock*> getTerminalPrimalBlocks(IRGlobalValueWithCode* func) + { + // 'Terminal' primal blocks are those that branch into a differential block. + List<IRBlock*> terminalPrimalBlocks; + for (auto block : func->getBlocks()) + for (auto successor : block->getSuccessors()) + if (!isDifferentialInst(block) && isDifferentialInst(successor)) + terminalPrimalBlocks.add(block); + + return terminalPrimalBlocks; + } + + List<IRBlock*> getTerminalDiffBlocks(IRGlobalValueWithCode* func) + { + // Terminal differential blocks are those with a return statement. + // Note that this method is designed to work with Fwd-Mode blocks, + // and this logic will be different for Rev-Mode blocks. + // + List<IRBlock*> terminalDiffBlocks; + for (auto block : func->getBlocks()) + if (as<IRReturn>(block->getTerminator())) + terminalDiffBlocks.add(block); + + return terminalDiffBlocks; + } + + IRInst* addPredecessorForBlock(IRBlock* block, IRBlock* predBlock) + { + if (!this->blockEntries.ContainsKey(block)) + { + // We haven't encountered this block yet, create a var for this in the + // first code block. + auto firstCodeBlock = getFirstCodeBlock(block->getParent()); + + IRBuilder subBuilder(this->autodiffContext->sharedBuilder); + subBuilder.setInsertBefore(firstCodeBlock->getTerminator()); + auto controlVar = subBuilder.emitVar(subBuilder.getUIntType()); + + ControlFlowTranspositionInfo info; + info.controlVar = controlVar; + + this->blockEntries[block] = info; + } + + auto info = this->blockEntries[block]; + + // Does precessor block already exist? + if (info.GetValue().predEntries.ContainsKey(predBlock)) + { + return info.GetValue().predEntries[predBlock]; + } + + // Otherwise, create an entry.. + auto uniqueIndex = info.GetValue().predEntries.Count(); + + IRBuilder builder(this->autodiffContext->sharedBuilder); + auto uniqueIndexLiteral = builder.getIntValue(builder.getUIntType(), uniqueIndex); + + info.GetValue().predEntries[predBlock] = uniqueIndexLiteral; + + return uniqueIndexLiteral; + } + + IRVar* getControlVar(IRBlock* block) + { + return this->blockEntries[block].GetValue().controlVar; + } + + // Inserts a block between the branch from fwdPredecessorBlock to fwdBlock, which sets a control + // variable to a unique index. + // + IRInst* insertPreludeForPredecessor(IRBlock* fwdBlock, IRBlock* fwdPredecessorBlock) + { + // Get associated primal blocks for both the differential blocks. + auto primalPredecessorBlock = getPrimalBlock(fwdPredecessorBlock); + SLANG_ASSERT(primalPredecessorBlock); + + auto primalBlock = getPrimalBlock(fwdBlock); + SLANG_ASSERT(primalBlock); + + // Add this block as a predecessor, and get an unique index (as an integer literal) + auto indexVal = addPredecessorForBlock(fwdBlock, fwdPredecessorBlock); + + IRBuilder subBuilder(this->autodiffContext->sharedBuilder); + subBuilder.setInsertInto(primalPredecessorBlock->getParent()); + + IRInst* preludeBlock = subBuilder.emitBlock(); + preludeBlock->insertAfter(primalPredecessorBlock); + + // Copy over phi parameters. + List<IRInst*> phiParams; + for (auto param = primalBlock->getFirstParam(); param; param = param->getNextParam()) + { + phiParams.add(subBuilder.emitParam(param->getDataType())); + } + + auto controlVar = getControlVar(fwdBlock); + subBuilder.emitStore(controlVar, indexVal); + + // Branch into the successor block using all the same phi parameters. + subBuilder.emitBranch(primalBlock, phiParams.getCount(), phiParams.getBuffer()); + + // Scan through uses of primalBlock to find the ones that are in + // primalPredecessorBlock, and replace them with branches to + // preludeBlock. + // + List<IRUse*> relevantUses; + for (auto use = primalBlock->firstUse; use; use = use->nextUse) + { + if (use->getUser()->getParent() == primalPredecessorBlock) + relevantUses.add(use); + } + + for (auto use : relevantUses) + use->set(preludeBlock); + + return indexVal; + } + + bool doesBlockHaveDifferentialPredecessors(IRBlock* fwdBlock) + { + for (auto block : fwdBlock->getPredecessors()) + { + if (isDifferentialInst(block)) + { + return true; + } + } + + return false; + } + + void emitTerminator(IRBuilder* builder, IRBlock* fwdBlockInst, List<IRInst*> phiParamGrads) + { + // If this block has no differential predecessors, add a return statement. + if (!doesBlockHaveDifferentialPredecessors(fwdBlockInst)) + { + // Emit a void return. + builder->emitReturn(); + return; + } + + for (auto predecessor : fwdBlockInst->getPredecessors()) + { + // Insert code into the *primal* version of the predecessor block + // to set the control variable to indexVal before branching. + // + insertPreludeForPredecessor(fwdBlockInst, predecessor); + } + + List<IRBlock*> revPredecessorBlocks; + List<IRInst*> indexVals; + + for (auto blockEntry : this->blockEntries[fwdBlockInst].GetValue().predEntries) + { + revPredecessorBlocks.add(revBlockMap[blockEntry.Key]); + indexVals.add(blockEntry.Value); + } + + auto predCount = revPredecessorBlocks.getCount(); + + SLANG_ASSERT(predCount > 0); + + List<IRBlock*> intermediateBranchBlocks; + + IRBuilder branchBlockBuilder(builder->getSharedBuilder()); + + branchBlockBuilder.setInsertInto(builder->getFunc()); + + // Make a block to unconditionally branch into predecessor-0 with the + // appropriate phi gradients. + // + auto firstBranchBlock = branchBlockBuilder.emitBlock(); + intermediateBranchBlocks.add(firstBranchBlock); + + branchBlockBuilder.markInstAsDifferential(firstBranchBlock); + branchBlockBuilder.emitBranch( + revPredecessorBlocks[0], + phiParamGrads.getCount(), + phiParamGrads.getBuffer()); + + // Create a builder to insert loads and comparison insts to figure + // out which block to branch into based on the control vars. + // This builder is set up to emit into the last _primal_ block. + // + IRBuilder booleanIndicatorBuilder(builder->getSharedBuilder()); + auto terminalPrimalBlock = getTerminalPrimalBlocks(builder->getFunc())[0]; + + booleanIndicatorBuilder.setInsertBefore(terminalPrimalBlock->getTerminator()); + + if (predCount == 1) + { + builder->emitBranch(firstBranchBlock); + } + else + { + IRBuilder ladderBlockBuilder(builder->getSharedBuilder()); + ladderBlockBuilder.setInsertInto(builder->getFunc()); + + // TODO: For now, we're trivially setting 'afterBlock' to + // the first reverse block. This is not really optimal for the + // restructuring passes since the 'then' and 'else' regions + // can have significant overlap. + // + auto firstFwdDiffBlock = (*terminalPrimalBlock->getSuccessors().begin()); + SLANG_ASSERT(firstFwdDiffBlock); + + auto defaultAfterBlock = revBlockMap[firstFwdDiffBlock]; + + auto nextLadderBlock = firstBranchBlock; + for (Index ii = 0; ii < predCount - 1; ii++) + { + // Make the 'leaf' block. This just branches into + // predecessor-i+1 with the appropriate phi args. + // + branchBlockBuilder.setInsertInto(branchBlockBuilder.getFunc()); + + auto thisIndexBlock = branchBlockBuilder.emitBlock(); + intermediateBranchBlocks.add(thisIndexBlock); + + branchBlockBuilder.markInstAsDifferential(thisIndexBlock); + branchBlockBuilder.emitBranch( + revPredecessorBlocks[ii+1], + phiParamGrads.getCount(), + phiParamGrads.getBuffer()); + + // Emit a boolean inst to represent whether we need to branch into + // block ii. + auto blockIndicatorInst = booleanIndicatorBuilder.emitEql( + booleanIndicatorBuilder.emitLoad(getControlVar(fwdBlockInst)), + indexVals[ii+1]); + + // Create a block to branch between i+1 and the rest of the ladder so far + // (0 ... i) + // + auto upperLadderBlock = ladderBlockBuilder.emitBlock(); + intermediateBranchBlocks.add(upperLadderBlock); + + ladderBlockBuilder.markInstAsDifferential(upperLadderBlock); + ladderBlockBuilder.emitIfElse( + blockIndicatorInst, + thisIndexBlock, + nextLadderBlock, + defaultAfterBlock); + + nextLadderBlock = upperLadderBlock; + } + + // Branch into the last ladder block. + builder->emitBranch(nextLadderBlock); + } + + // Insert all intermediate blocks in the order they were created, right after + // the current reverse block. + + auto revBlock = revBlockMap[fwdBlockInst]; + SLANG_ASSERT(revBlock); + + for (auto block : intermediateBranchBlocks) + { + block->insertAfter(revBlock); + } + + return; + } TranspositionResult transposeInst(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue) { + // Dispatch logic. switch(fwdInst->getOp()) { case kIROp_Add: case kIROp_Mul: - case kIROp_Sub: + case kIROp_Sub: return transposeArithmetic(builder, fwdInst, revValue); case kIROp_Call: @@ -413,6 +900,16 @@ struct DiffTransposePass case kIROp_MakeVector: return transposeMakeVector(builder, fwdInst, revValue); + + case kIROp_unconditionalBranch: + case kIROp_conditionalBranch: + case kIROp_ifElse: + case kIROp_loop: + { + // Ignore. transposeBlock() should take care of adding the + // appropriate branch instruction. + return TranspositionResult(); + } default: SLANG_ASSERT_FAILURE("Unhandled instruction"); @@ -470,7 +967,6 @@ struct DiffTransposePass return TranspositionResult(List<RevGradient>()); } - TranspositionResult transposeStore(IRBuilder* builder, IRStore* fwdStore, IRInst*) { @@ -935,71 +1431,6 @@ struct DiffTransposePass nullptr); } - IRInst* emitAggregateDifferentialPair(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> pairGradients) - { - SLANG_UNEXPECTED("Should not run."); - - auto aggPairType = as<IRDifferentialPairType>(aggPrimalType); - SLANG_ASSERT(aggPairType); - - IRType* diffType = (IRType*)pairBuilder.getDiffTypeFromPairType(builder, aggPairType); - - IRInst* primalInst = nullptr; - IRInst* diffInst = nullptr; - - List<RevGradient> gradients; - for (auto gradient : pairGradients) - { - switch (gradient.flavor) - { - case RevGradient::Flavor::Simple: - { - // In this case, the gradient is a 'pair' already, but we need to treat the primal element - // as if it didn't exist (we simply copy it over) - // If we already saw a pair, throw an error since we don't know how to combine to primals. - // (i.e. something went wrong prior to this step.) - // - if (primalInst) - { - SLANG_UNEXPECTED("Encountered multiple pair types in emitAggregateDifferentialPair"); - } - - primalInst = builder->emitDifferentialPairGetPrimal(gradient.revGradInst); - gradients.add( - RevGradient( - RevGradient::Flavor::Simple, - gradient.targetInst, - builder->emitDifferentialPairGetDifferential( - diffType, - gradient.revGradInst), - gradient.fwdGradInst)); - break; - } - - case RevGradient::Flavor::GetDifferential: - { - // In this case, the gradient is the result of transposing a GetDifferential - // so we have only the gradient part. Just add it to the list of gradients to aggregate - gradients.add( - RevGradient( - RevGradient::Flavor::Simple, - gradient.targetInst, - gradient.revGradInst, - gradient.fwdGradInst)); - break; - } - default: - SLANG_UNEXPECTED("Unexpected gradient flavor in emitAggregateDifferentialPair"); - } - } - - // Aggregate only the differentials - diffInst = emitAggregateValue(builder, aggPairType->getValueType(), gradients); - - // Pack them back together. - return builder->emitMakeDifferentialPair(aggPrimalType, primalInst, diffInst); - } - IRInst* emitAggregateValue(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> gradients) { // If we're dealing with the differential-pair types, we need to use a different aggregation method, since @@ -1138,17 +1569,25 @@ struct DiffTransposePass return gradientsMap.ContainsKey(fwdInst); } - AutoDiffSharedContext* autodiffContext; + AutoDiffSharedContext* autodiffContext; + + DifferentiableTypeConformanceContext diffTypeContext; - DifferentiableTypeConformanceContext diffTypeContext; + DifferentialPairTypeBuilder pairBuilder; - DifferentialPairTypeBuilder pairBuilder; + Dictionary<IRInst*, List<RevGradient>> gradientsMap; - Dictionary<IRInst*, List<RevGradient>> gradientsMap; + Dictionary<IRInst*, IRVar*> revAccumulatorVarMap; + + Dictionary<IRInst*, IRInst*>* primalsMap; + + List<IRInst*> usedPtrs; + + Dictionary<IRBlock*, ControlFlowTranspositionInfo> blockEntries; - Dictionary<IRInst*, IRInst*>* primalsMap; + Dictionary<IRBlock*, IRBlock*> revBlockMap; - List<IRInst*> usedPtrs; + Dictionary<IRGlobalValueWithCode*, IRBlock*> firstRevDiffBlockMap; }; diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 546d5a6ec..2fd53dbd0 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -91,7 +91,7 @@ struct ExtractPrimalFuncContext for (UInt i = 0; i < originalFuncType->getParamCount(); i++) paramTypes.add(originalFuncType->getParamType(i)); paramTypes.add(builder.getInOutType((IRType*)outIntermediateType)); - auto newFuncType = builder.getFuncType(paramTypes, originalFuncType->getResultType()); + auto newFuncType = builder.getFuncType(paramTypes, builder.getVoidType()); return newFuncType; } @@ -100,6 +100,10 @@ struct ExtractPrimalFuncContext if (inst->findDecoration<IRDifferentialInstDecoration>() || inst->findDecoration<IRMixedDifferentialInstDecoration>()) return true; + + if (auto block = as<IRBlock>(inst->getParent())) + return isDiffInst(block); + return false; } @@ -161,6 +165,7 @@ struct ExtractPrimalFuncContext case kIROp_DoubleType: case kIROp_VectorType: case kIROp_MatrixType: + case kIROp_BoolType: case kIROp_Param: case kIROp_Specialize: case kIROp_LookupWitness: @@ -383,47 +388,76 @@ struct ExtractPrimalFuncContext genericMigrationContext.init(gen, as<IRGeneric>(spec->getBase())); } + List<IRBlock*> diffBlocksList; + List<IRBlock*> primalBlocksList; + for (auto block : func->getBlocks()) { if (block == paramBlock) continue; - if (block->findDecoration<IRDifferentialInstDecoration>() || - block->findDecoration<IRMixedDifferentialInstDecoration>()) + + if (isDiffInst(block)) + diffBlocksList.add(block); + else + primalBlocksList.add(block); + } + + // Go over primal blocks and store insts. + for (auto block : primalBlocksList) + { + // For primal insts, decide whether or not to store its result in + // output intermediary struct. + for (auto inst : block->getChildren()) { - if (block->getFirstParam() == nullptr) + if (shouldStoreInst(inst)) { - // If the block does not have any PHI nodes, just remove it and - // replace all its uses with returnBlock. - block->replaceUsesWith(returnBlock); - block->removeAndDeallocate(); - } - else - { - // If the block has Phi nodes, we can't directly replace it with - // `returnBlock`, but we can turn the block into a trivial branch - // into `returnBlock` to safely preserve the invariants of Phi nodes. - auto inst = block->getLastParam()->getNextInst(); - for (; inst; inst = inst->getNextInst()) - inst->removeAndDeallocate(); - builder.setInsertInto(block); - builder.emitBranch(returnBlock); + builder.setInsertAfter(inst); + storeInst(builder, inst, genericMigrationContext, outIntermediary); } } + } + + // Go over differential blocks and complete + for (auto block : diffBlocksList) + { + + if (block->getFirstParam() == nullptr) + { + // If the block does not have any PHI nodes, just remove it and + // replace all its uses with returnBlock. + + // TODO: This invalides the next block in the chain. Make a list first. + block->replaceUsesWith(returnBlock); + block->removeAndDeallocate(); + } else { - // For primal insts, decide whether or not to store its result in - // output intermediary struct. - for (auto inst : block->getChildren()) + // If the block has Phi nodes, we can't directly replace it with + // `returnBlock`, but we can turn the block into a trivial branch + // into `returnBlock` to safely preserve the invariants of Phi nodes. + auto inst = block->getLastParam()->getNextInst(); + for (; inst;) { - if (shouldStoreInst(inst)) - { - builder.setInsertAfter(inst); - storeInst(builder, inst, genericMigrationContext, outIntermediary); - } + auto nextInst = inst->getNextInst(); + inst->removeAndDeallocate(); + inst = nextInst; } + + builder.setInsertInto(block); + builder.emitBranch(returnBlock); } } + List<IRBlock*> unusedBlocks; + for (auto block : func->getBlocks()) + { + if (!block->hasUses() && isDiffInst(block)) + unusedBlocks.add(block); + } + + for (auto block : unusedBlocks) + block->removeAndDeallocate(); + builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); auto defVal = builder.emitDefaultConstructRaw((IRType*)intermediateType); builder.emitStore(outIntermediary, defVal); @@ -503,13 +537,17 @@ IRGlobalValueWithCode* DiffUnzipPass::extractPrimalFunc( if (auto structKeyDecor = inst->findDecoration<IRPrimalValueStructKeyDecoration>()) { builder.setInsertBefore(inst); - auto addr = builder.emitFieldAddress(builder.getPtrType(inst->getDataType()), intermediateVar, structKeyDecor->getStructKey()); + auto addr = builder.emitFieldAddress( + builder.getPtrType(inst->getDataType()), + intermediateVar, + structKeyDecor->getStructKey()); auto val = builder.emitLoad(addr); inst->replaceUsesWith(val); instsToRemove.add(inst); } } } + for (auto inst : instsToRemove) { inst->removeAndDeallocate(); @@ -517,6 +555,7 @@ IRGlobalValueWithCode* DiffUnzipPass::extractPrimalFunc( // Run simplification to DCE unnecessary insts. eliminateDeadCode(innerFunc); + eliminateDeadCode(specializedPrimalFunc); return primalFunc; } diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 35aa55dd3..2c55b390b 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -26,6 +26,13 @@ struct DiffUnzipPass Dictionary<IRInst*, IRInst*> primalMap; Dictionary<IRInst*, IRInst*> diffMap; + // First diff block. + // TODO: Can the same pass object can be used for multiple functions? + // might run into an issue here? + IRBlock* firstDiffBlock; + + // Dictionary<IRBlock*, List<IRBlock*>> + DiffUnzipPass(AutoDiffSharedContext* autodiffContext) : autodiffContext(autodiffContext), diffTypeContext(autodiffContext) { } @@ -58,34 +65,70 @@ struct DiffUnzipPass builder->setInsertInto(unzippedFunc); - // Work *only* with two-block functions for now. + // Functions need to have at least two blocks at this point (one for parameters, + // and atleast one for code) + // SLANG_ASSERT(unzippedFunc->getFirstBlock() != nullptr); SLANG_ASSERT(unzippedFunc->getFirstBlock()->getNextBlock() != nullptr); - SLANG_ASSERT(unzippedFunc->getFirstBlock()->getNextBlock()->getNextBlock() == nullptr); // Ignore the first block (this is reserved for parameters), start // at the second block. (For now, we work with only a single block of insts) // TODO: expand to handle multi-block functions later. + IRBlock* firstBlock = unzippedFunc->getFirstBlock()->getNextBlock(); - IRBlock* mainBlock = unzippedFunc->getFirstBlock()->getNextBlock(); + List<IRBlock*> mixedBlocks; + for (IRBlock* block = firstBlock; block; block = block->getNextBlock()) + { + // Only need to unzip blocks with both differential and primal instructions. + if (block->findDecoration<IRMixedDifferentialInstDecoration>()) + { + mixedBlocks.add(block); + } + } + + IRBlock* firstPrimalBlock = nullptr; - // Emit new blocks for split vesions of mainblock. - IRBlock* primalBlock = builder->emitBlock(); - IRBlock* diffBlock = builder->emitBlock(); + // Emit an empty primal block for every mixed block. + for (auto block : mixedBlocks) + { + IRBlock* primalBlock = builder->emitBlock(); + primalMap[block] = primalBlock; - // Mark the differential block as a differential inst. - builder->markInstAsDifferential(diffBlock); + if (block == firstBlock) + firstPrimalBlock = primalBlock; + } - // Split the main block into two. This method should also emit - // a branch statement from primalBlock to diffBlock. - // TODO: extend this code to split multiple blocks - // - splitBlock(mainBlock, primalBlock, diffBlock); + // Emit an empty differential block for every mixed block. + for (auto block : mixedBlocks) + { + IRBlock* diffBlock = builder->emitBlock(); + diffMap[block] = diffBlock; + + // Mark the differential block as a differential inst + // (and add a reference to the primal block) + builder->markInstAsDifferential(diffBlock, nullptr, primalMap[block]); + + // Record the first differential (code) block, + // since we want all 'return' insts in primal blocks + // to be replaced with a brahcn into this block. + // + if (block == firstBlock) + this->firstDiffBlock = diffBlock; + } + + // Split each block into two. + for (auto block : mixedBlocks) + { + splitBlock(block, as<IRBlock>(primalMap[block]), as<IRBlock>(diffMap[block])); + } + + // Swap the first block's occurences out for the first primal block. + firstBlock->replaceUsesWith(firstPrimalBlock); + + // Remove old blocks. + for (auto block : mixedBlocks) + block->removeAndDeallocate(); - // Replace occurences of mainBlock with primalBlock - mainBlock->replaceUsesWith(primalBlock); - mainBlock->removeAndDeallocate(); - return unzippedFunc; } @@ -221,10 +264,14 @@ struct DiffUnzipPass return InstPair(primalBuilder->emitVar(primalType), diffBuilder->emitVar(diffType)); } - InstPair splitReturn(IRBuilder*, IRBuilder* diffBuilder, IRReturn* mixedReturn) + InstPair splitReturn(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRReturn* mixedReturn) { auto pairType = as<IRDifferentialPairType>(mixedReturn->getVal()->getDataType()); auto primalType = pairType->getValueType(); + + // Check that we have an unambiguous 'first' differential block. + SLANG_ASSERT(firstDiffBlock); + auto primalBranch = primalBuilder->emitBranch(firstDiffBlock); auto pairVal = diffBuilder->emitMakeDifferentialPair( pairType, @@ -235,7 +282,81 @@ struct DiffUnzipPass auto returnInst = diffBuilder->emitReturn(pairVal); diffBuilder->markInstAsDifferential(returnInst, primalType); - return InstPair(nullptr, returnInst); + return InstPair(primalBranch, returnInst); + } + + InstPair splitControlFlow(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRInst* branchInst) + { + switch (branchInst->getOp()) + { + case kIROp_unconditionalBranch: + { + auto uncondBranchInst = as<IRUnconditionalBranch>(branchInst); + auto targetBlock = uncondBranchInst->getTargetBlock(); + + // Split args. + List<IRInst*> primalArgs; + List<IRInst*> diffArgs; + for (UIndex ii = 0; ii < uncondBranchInst->getArgCount(); ii++) + { + if (isDifferentialInst(uncondBranchInst->getArg(ii))) + diffArgs.add(uncondBranchInst->getArg(ii)); + else + primalArgs.add(uncondBranchInst->getArg(ii)); + } + + return InstPair( + primalBuilder->emitBranch( + as<IRBlock>(primalMap[targetBlock]), + primalArgs.getCount(), + primalArgs.getBuffer()), + diffBuilder->emitBranch( + as<IRBlock>(diffMap[targetBlock]), + diffArgs.getCount(), + diffArgs.getBuffer())); + + } + + case kIROp_conditionalBranch: + { + auto trueBlock = as<IRConditionalBranch>(branchInst)->getTrueBlock(); + auto falseBlock = as<IRConditionalBranch>(branchInst)->getFalseBlock(); + auto condInst = as<IRConditionalBranch>(branchInst)->getCondition(); + + return InstPair( + primalBuilder->emitBranch( + condInst, + as<IRBlock>(primalMap[trueBlock]), + as<IRBlock>(primalMap[falseBlock])), + diffBuilder->emitBranch( + condInst, + as<IRBlock>(diffMap[trueBlock]), + as<IRBlock>(diffMap[falseBlock]))); + } + + case kIROp_ifElse: + { + auto trueBlock = as<IRIfElse>(branchInst)->getTrueBlock(); + auto falseBlock = as<IRIfElse>(branchInst)->getFalseBlock(); + auto afterBlock = as<IRIfElse>(branchInst)->getAfterBlock(); + auto condInst = as<IRIfElse>(branchInst)->getCondition(); + + return InstPair( + primalBuilder->emitIfElse( + condInst, + as<IRBlock>(primalMap[trueBlock]), + as<IRBlock>(primalMap[falseBlock]), + as<IRBlock>(primalMap[afterBlock])), + diffBuilder->emitIfElse( + condInst, + as<IRBlock>(diffMap[trueBlock]), + as<IRBlock>(diffMap[falseBlock]), + as<IRBlock>(diffMap[afterBlock]))); + } + + default: + SLANG_UNEXPECTED("Unhandled instruction"); + } } InstPair _splitMixedInst(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRInst* inst) @@ -257,6 +378,15 @@ struct DiffUnzipPass case kIROp_Return: return splitReturn(primalBuilder, diffBuilder, as<IRReturn>(inst)); + case kIROp_unconditionalBranch: + case kIROp_conditionalBranch: + case kIROp_ifElse: + return splitControlFlow(primalBuilder, diffBuilder, inst); + + case kIROp_Unreachable: + return InstPair(primalBuilder->emitUnreachable(), + diffBuilder->emitUnreachable()); + default: SLANG_ASSERT_FAILURE("Unhandled mixed diff inst"); } @@ -270,7 +400,7 @@ struct DiffUnzipPass diffMap[inst] = instPair.differential; } - void splitBlock(IRBlock* mainBlock, IRBlock* primalBlock, IRBlock* diffBlock) + void splitBlock(IRBlock* block, IRBlock* primalBlock, IRBlock* diffBlock) { // Make two builders for primal and differential blocks. IRBuilder primalBuilder; @@ -282,12 +412,13 @@ struct DiffUnzipPass diffBuilder.setInsertInto(diffBlock); List<IRInst*> splitInsts; - for (auto child = mainBlock->getFirstChild(); child;) + for (auto child = block->getFirstChild(); child;) { IRInst* nextChild = child->getNextInst(); if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(child)) { + // Replace GetDiff(A) with A.d if (diffMap.ContainsKey(getDiffInst->getBase())) { getDiffInst->replaceUsesWith(lookupDiffInst(getDiffInst->getBase())); @@ -296,9 +427,9 @@ struct DiffUnzipPass continue; } } - - if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(child)) + else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(child)) { + // Replace GetPrimal(A) with A.p if (primalMap.ContainsKey(getPrimalInst->getBase())) { getPrimalInst->replaceUsesWith(lookupPrimalInst(getPrimalInst->getBase())); @@ -339,12 +470,12 @@ struct DiffUnzipPass } // Nothing should be left in the original block. - SLANG_ASSERT(mainBlock->getFirstChild() == nullptr); + SLANG_ASSERT(block->getFirstChild() == nullptr); // Branch from primal to differential block. // Functionally, the new blocks should produce the same output as the // old block. - primalBuilder.emitBranch(diffBlock); + // primalBuilder.emitBranch(diffBlock); } }; diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index f0ec1542e..40c24d11d 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -421,6 +421,32 @@ void stripAutoDiffDecorations(IRModule* module) stripAutoDiffDecorationsFromChildren(module->getModuleInst()); } + +void stripBlockTypeDecorations(IRFunc* func) +{ + for (auto child : func->getChildren()) + { + if (auto block = as<IRBlock>(child)) + { + for (auto decor = block->getFirstDecoration(); decor; ) + { + auto next = decor->getNextDecoration(); + switch (decor->getOp()) + { + case kIROp_DifferentialInstDecoration: + case kIROp_MixedDifferentialInstDecoration: + decor->removeAndDeallocate(); + break; + default: + break; + } + decor = next; + } + } + } +} + + struct StripNoDiffTypeAttributePass : InstPassBase { StripNoDiffTypeAttributePass(IRModule* module) : @@ -484,7 +510,7 @@ struct AutoDiffPass : public InstPassBase { bool changed = false; List<IRInst*> autoDiffWorkList; - // Collect all `ForwardDifferentiate` insts from the module. + // Collect all `ForwardDifferentiate`/`BackwardDifferentiate` insts from the module. autoDiffWorkList.clear(); processAllInsts([&](IRInst* inst) { @@ -541,6 +567,7 @@ struct AutoDiffPass : public InstPassBase // Run transcription logic to generate the body of forward/backward derivatives functions. // While doing so, we may discover new functions to differentiate, so we keep running until // the worklist goes dry. + List<IRFunc*> autodiffCleanupList; while (autodiffContext->followUpFunctionsToTranscribe.getCount() != 0) { changed = true; @@ -549,6 +576,14 @@ struct AutoDiffPass : public InstPassBase { auto diffFunc = as<IRFunc>(task.resultFunc); SLANG_ASSERT(diffFunc); + + // We're running in to some situations where the follow-up task + // has already been completed (diffFunc has been generated, processed, + // and deallocated). Skip over these for now. + // + if (!diffFunc->getDataType()) + continue; + auto primalFunc = as<IRFunc>(task.originalFunc); SLANG_ASSERT(primalFunc); switch (task.type) @@ -562,12 +597,26 @@ struct AutoDiffPass : public InstPassBase default: break; } + + autodiffCleanupList.add(diffFunc); } } + + // Get rid of block-level decorations that are used to keep track of + // different block types. These don't work well with the IR simplification + // passes since they don't expect decorations in blocks. + // + for (auto diffFunc : autodiffCleanupList) + stripBlockTypeDecorations(diffFunc); + + autodiffCleanupList.clear(); + if (!changed) break; hasChanges |= changed; } + + return hasChanges; } diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 6373334bf..03a3fb063 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -608,6 +608,7 @@ struct IRDifferentialInstDecoration : IRDecoration IR_LEAF_ISA(DifferentialInstDecoration) IRType* getPrimalType() { return as<IRType>(getOperand(0)); } + IRInst* getPrimalInst() { return as<IRInst>(getOperand(1)); } }; struct IRPrimalValueStructKeyDecoration : IRDecoration @@ -3423,6 +3424,11 @@ public: addDecoration(value, kIROp_DifferentialInstDecoration, primalType); } + void markInstAsDifferential(IRInst* value, IRType* primalType, IRInst* primalInst) + { + addDecoration(value, kIROp_DifferentialInstDecoration, primalType, primalInst); + } + void addCOMWitnessDecoration(IRInst* value, IRInst* witnessTable) { addDecoration(value, kIROp_COMWitnessDecoration, &witnessTable, 1); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 33130cfb3..d8a8fb7c4 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -6623,6 +6623,10 @@ namespace Slang case kIROp_Reinterpret: case kIROp_GetNativePtr: return false; + + case kIROp_ForwardDifferentiate: + case kIROp_BackwardDifferentiate: + return false; } } |
