diff options
| -rw-r--r-- | docs/user-guide/07-autodiff.md | 4 | ||||
| -rw-r--r-- | source/slang/core.meta.slang | 3 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-check-modifier.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 67 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 38 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 94 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 279 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 11 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 62 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 4 | ||||
| -rw-r--r-- | tests/autodiff/reverse-loop.slang | 8 | ||||
| -rw-r--r-- | tests/autodiff/reverse-loop.slang.expected.txt | 6 |
16 files changed, 478 insertions, 124 deletions
diff --git a/docs/user-guide/07-autodiff.md b/docs/user-guide/07-autodiff.md index d72299229..488800a98 100644 --- a/docs/user-guide/07-autodiff.md +++ b/docs/user-guide/07-autodiff.md @@ -567,6 +567,6 @@ The compiler can generate forward derivative and backward propagation implementa - No access to global variables or shader parameters within a differentiable function. - All operations to global resources, including texture reads or atomic writes, are treating as a non-differentiable operation. - If a differentiable function contains calls that cause side-effects such as updates to global memory, there will not be a guarantee on how many times the side-effect will occur during the resulting derivative function or back-propagation function. -- All loops in a backward differentiable function must end within a statically known number of iterations. If the maximum number of iterations is not trivially deductible by the type system as a compile-time constant, a manually attribute is needed at the loop to provide the number. If the number of actually executed iterations exceeds what is being specified, the resulting runtime behavior is undefined. +- `for` loops: In a backward differentiable function, loops currently cannot have `continue` statements, although `break` statements are supported. Loops must use the attribute `[MaxIters(<count>)]` to specify a maximum number of iterations. This will be used by the compiler to allocate space to store intermediate data. If the actual number of iterations exceeds the provided maximum, the behavior is undefined. -The above restrictions do no apply if a user-defined derivative or backward propagation function is provided. +The above restrictions do not apply if a user-defined derivative or backward propagation function is provided. 
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 31dd5ed29..533713016 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2823,6 +2823,9 @@ attribute_syntax [fastopt] : FastOptAttribute; __attributeTarget(LoopStmt) attribute_syntax [allow_uav_condition] : AllowUAVConditionAttribute; +__attributeTarget(LoopStmt) +attribute_syntax [MaxIters(count)] : MaxItersAttribute; + __attributeTarget(IfStmt) attribute_syntax [flatten] : FlattenAttribute; diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 666ca77ea..42b79ca4a 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -620,6 +620,14 @@ class UnrollAttribute : public Attribute IntegerLiteralValue getCount(); }; +// An `[maxiters(count)]` +class MaxItersAttribute : public Attribute +{ + SLANG_AST_CLASS(MaxItersAttribute) + + int32_t value = 0; +}; + class LoopAttribute : public Attribute { SLANG_AST_CLASS(LoopAttribute) diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index e73f04301..9f3e79978 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -507,6 +507,17 @@ namespace Slang // as 1 arg if nothing is specified) SLANG_ASSERT(attr->args.getCount() == 1); } + else if (auto maxItersAttrs = as<MaxItersAttribute>(attr)) + { + if (auto cint = checkConstantIntVal(attr->args[0])) + { + maxItersAttrs->value = (int32_t) cint->value; + } + else + { + getSink()->diagnose(attr, Diagnostics::notEnoughArguments, attr->args.getCount(), 1); + } + } else if (auto userDefAttr = as<UserDefinedAttribute>(attr)) { // check arguments against attribute parameters defined in attribClassDecl diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 04acad435..fca34f9a2 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -149,6 +149,7 @@ 
InstPair ForwardDiffTranscriber::transcribeBinaryArith(IRBuilder* builder, IRIns builder->markInstAsDifferential(diffSub, resultType); auto diffMul = builder->emitMul(resultType, primalRight, primalRight); + builder->markInstAsPrimal(diffMul); auto diffDiv = builder->emitDiv(diffType, diffSub, diffMul); builder->markInstAsDifferential(diffDiv, resultType); @@ -881,6 +882,29 @@ InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRI return InstPair(primalUpdateField, diffUpdateElement); } +List<IRInst*> ForwardDiffTranscriber::transcribePhiArgs(IRBuilder* builder, List<IRInst*> origPhiArgs) +{ + // Grab the differentials for any phi nodes. + List<IRInst*> newArgs; + for (auto origArg : origPhiArgs) + { + auto primalArg = lookupPrimalInst(builder, origArg); + newArgs.add(primalArg); + + if (differentiateType(builder, origArg->getDataType())) + { + auto diffArg = lookupDiffInst(origArg, nullptr); + if (diffArg) + newArgs.add(diffArg); + else + newArgs.add( + getDifferentialZeroOfType(builder, origArg->getDataType())); + } + } + + return newArgs; +} + InstPair ForwardDiffTranscriber::transcribeLoop(IRBuilder* builder, IRLoop* origLoop) { // The loop comes with three blocks.. we just need to transcribe each one @@ -902,13 +926,14 @@ InstPair ForwardDiffTranscriber::transcribeLoop(IRBuilder* builder, IRLoop* orig diffLoopOperands.add(diffTargetBlock); diffLoopOperands.add(diffBreakBlock); diffLoopOperands.add(diffContinueBlock); - - // If there are any other operands, use their primal versions. 
+ + List<IRInst*> phiArgs; for (UIndex ii = diffLoopOperands.getCount(); ii < origLoop->getOperandCount(); ii++) - { - auto primalOperand = findOrTranscribePrimalInst(builder, origLoop->getOperand(ii)); - diffLoopOperands.add(primalOperand); - } + phiArgs.add(origLoop->getOperand(ii)); + + auto newPhiArgs = transcribePhiArgs(builder, phiArgs); + for (auto newArg : newPhiArgs) + diffLoopOperands.add(newArg); IRInst* diffLoop = builder->emitIntrinsicInst( nullptr, @@ -917,6 +942,9 @@ InstPair ForwardDiffTranscriber::transcribeLoop(IRBuilder* builder, IRLoop* orig diffLoopOperands.getBuffer()); builder->markInstAsMixedDifferential(diffLoop); + if (auto maxItersDecoration = origLoop->findDecoration<IRLoopMaxItersDecoration>()) + builder->addLoopMaxItersDecoration(diffLoop, maxItersDecoration->getMaxIters()); + return InstPair(diffLoop, diffLoop); } @@ -1211,6 +1239,28 @@ IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, I return diffFunc; } +void ForwardDiffTranscriber::checkAutodiffInstDecorations(IRFunc* fwdFunc) +{ + for (auto block = fwdFunc->getFirstBlock(); block; block = block->getNextBlock()) + { + for (auto inst = block->getFirstOrdinaryInst(); inst; inst = inst->getNextInst()) + { + // TODO: Special case, not sure why these insts show up + if (as<IRUndefined>(inst)) continue; + + List<IRDecoration*> decorations; + for (auto decoration : inst->getDecorations()) + { + if (as<IRAutodiffInstDecoration>(decoration)) + decorations.add(decoration); + } + + // Must have _exactly_ one autodiff tag. + SLANG_ASSERT(decorations.getCount() == 1); + } + } +} + // Transcribe a function definition. 
InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc) { @@ -1266,6 +1316,10 @@ InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* pr } } } + +#if _DEBUG + checkAutodiffInstDecorations(diffFunc); +#endif return InstPair(primalFunc, diffFunc); } @@ -1310,7 +1364,6 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_MakeMatrix: case kIROp_MakeMatrixFromScalar: case kIROp_MatrixReshape: - case kIROp_VectorReshape: case kIROp_IntCast: case kIROp_FloatCast: case kIROp_MakeVectorFromScalar: diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index 260b0a433..e80b25754 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -94,6 +94,10 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase // Transcribe a function without marking the result as a decoration. IRFunc* transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc); + List<IRInst*> transcribePhiArgs(IRBuilder* builder, List<IRInst*> origPhiArgs); + + void checkAutodiffInstDecorations(IRFunc* fwdFunc); + // Create an empty func to represent the transcribed func of `origFunc`. virtual InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) override; diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 702f9819a..20090ca42 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -765,7 +765,7 @@ namespace Slang // Transpose differential blocks from unzippedFwdDiffFunc into diffFunc (with dOutParameter) representing the // derivative of the return value. 
- DiffTransposePass::FuncTranspositionInfo info = { paramTransposeInfo.dOutParam, nullptr}; + DiffTransposePass::FuncTranspositionInfo info = { paramTransposeInfo.dOutParam }; diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, info); eliminateDeadCode(diffPropagateFunc); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 31a3072c0..10a734d65 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -901,6 +901,15 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst { mapPrimalInst(origInst, pair.primal); mapDifferentialInst(origInst, pair.differential); + + + if (pair.primal != pair.differential && + !pair.primal->findDecoration<IRAutodiffInstDecoration>() && + !as<IRConstant>(pair.primal)) + { + builder->markInstAsPrimal(pair.primal); + } + if (pair.differential) { switch (pair.differential->getOp()) @@ -920,16 +929,27 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice()); } - // Tag the differential inst using a decoration (if it doesn't have one) - if (!pair.differential->findDecoration<IRDifferentialInstDecoration>() && - !pair.differential->findDecoration<IRMixedDifferentialInstDecoration>() && - !as<IRConstant>(pair.differential)) + // Automatically tag the primal and differential results + // if they haven't already been handled by the + // code. + // + if (pair.primal != pair.differential) + { + if (!pair.differential->findDecoration<IRAutodiffInstDecoration>() + && !as<IRConstant>(pair.differential)) + { + auto primalType = as<IRType>(pair.primal->getDataType()); + builder->markInstAsDifferential(pair.differential, primalType); + } + } + else { - // TODO: If the type is a 'relevant' pair type, need to mark it as mixed differential - // instead. 
- // - auto primalType = as<IRType>(pair.primal->getDataType()); - builder->markInstAsDifferential(pair.differential, primalType); + if (!pair.primal->findDecoration<IRAutodiffInstDecoration>() + && !as<IRConstant>(pair.differential)) + { + auto mixedType = as<IRType>(pair.primal->getDataType()); + builder->markInstAsMixedDifferential(pair.primal, mixedType); + } } break; diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index d9b28ea3c..2953c6206 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -78,11 +78,6 @@ struct DiffTransposePass // of the *output* of the function. // IRInst* dOutInst; - - // Mapping between *primal* insts in the forward-mode function, and the - // reverse-mode function - // - Dictionary<IRInst*, IRInst*>* primalsMap; }; struct PendingBlockTerminatorEntry @@ -353,6 +348,13 @@ struct DiffTransposePass getPhiGrads(firstLoopBlock).getBuffer()); } + auto phiGrads = getPhiGrads(condBlock); + if (phiGrads.getCount() > 0) + { + revTrueBlock = insertPhiBlockBefore(revTrueBlock, phiGrads); + revFalseBlock = insertPhiBlockBefore(revFalseBlock, phiGrads); + } + // Emit condition into the new cond block. builder.setInsertInto(revCondBlock); builder.emitIfElse( @@ -533,8 +535,6 @@ struct DiffTransposePass // firstRevDiffBlockMap[revDiffFunc] = revBlockMap[terminalDiffBlocks[0]]; - IRInst* retVal = nullptr; - for (auto block : workList) { // Set dOutParameter as the transpose gradient for the return inst, if any. 
@@ -543,7 +543,6 @@ struct DiffTransposePass if (auto returnInst = as<IRReturn>(block->getTerminator())) { this->addRevGradientForFwdInst(returnInst, RevGradient(returnInst, transposeInfo.dOutInst, nullptr)); - retVal = returnInst->getVal(); } } @@ -572,6 +571,11 @@ struct DiffTransposePass auto terminalPrimalBlock = terminalPrimalBlocks[0]; auto firstRevBlock = as<IRBlock>(revBlockMap[terminalDiffBlocks[0]]); + auto returnDecoration = + terminalPrimalBlock->getTerminator()->findDecoration<IRBackwardDerivativePrimalReturnDecoration>(); + SLANG_ASSERT(returnDecoration); + auto retVal = returnDecoration->getBackwardDerivativePrimalReturnValue(); + terminalPrimalBlock->getTerminator()->removeAndDeallocate(); IRBuilder subBuilder(builder.getSharedBuilder()); @@ -582,15 +586,6 @@ struct DiffTransposePass auto branch = subBuilder.emitBranch(firstRevBlock); - if (!retVal || retVal->getOp() == kIROp_VoidLit) - { - retVal = subBuilder.getVoidValue(); - } - else - { - auto makePair = cast<IRMakeDifferentialPair>(retVal); - retVal = makePair->getPrimalValue(); - } subBuilder.addBackwardDerivativePrimalReturnDecoration(branch, retVal); } @@ -610,6 +605,25 @@ struct DiffTransposePass } } + IRInst* extractAccumulatorVarGradient(IRBuilder* builder, IRInst* fwdInst) + { + if (auto accVar = getOrCreateAccumulatorVar(fwdInst)) + { + auto gradValue = builder->emitLoad(accVar); + builder->emitStore( + accVar, + emitDZeroOfDiffInstType( + builder, + tryGetPrimalTypeFromDiffInst(fwdInst))); + + return gradValue; + } + else + { + return nullptr; + } + } + // Fetch or create a gradient accumulator var // corresponding to a inst. These are used to // accumulate gradients across blocks. @@ -688,6 +702,30 @@ struct DiffTransposePass } } + // Some special instructions simply need to be copied over. + // These do not deal with differentials. + // TODO: This will not work if there are any differential + // insts that rely on loop counter vars having a specific + // value. 
+ // The solution is to have primal insts appearing in + // differential blocks be in their own special blocks that are + // ignored entirely, rather than dealing with them one inst + // at a time. + // + for (IRInst* child = fwdBlock->getFirstChild(); child;) + { + auto nextChild = child->getNextInst(); + + if (child->findDecoration<IRLoopCounterDecoration>()) + { + // Loop counter insts should not have any gradients. + SLANG_ASSERT(!hasRevGradients(child)); + child->insertAtEnd(revBlock); + } + + child = nextChild; + } + // Move pointer & reference insts to the top of the reverse-mode block. List<IRInst*> nonValueInsts; for (IRInst* child = fwdBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) @@ -719,6 +757,7 @@ struct DiffTransposePass if (as<IRDecoration>(child) || as<IRParam>(child)) continue; + transposeInst(&builder, child); } @@ -744,10 +783,10 @@ struct DiffTransposePass // if (isInstUsedOutsideParentBlock(param)) { - auto accVar = getOrCreateAccumulatorVar(param); + auto accGradient = extractAccumulatorVarGradient(&builder, param); addRevGradientForFwdInst( param, - RevGradient(param, builder.emitLoad(accVar), nullptr)); + RevGradient(param, accGradient, nullptr)); } if (hasRevGradients(param)) @@ -839,15 +878,6 @@ struct DiffTransposePass break; } - // Some special instructions simply need to be copied over. - // These do not deal with differentials. - // - if (inst->findDecoration<IRLoopCounterDecoration>()) - { - inst->insertAtEnd(builder->getBlock()); - return; - } - // Look for gradient entries for this inst. 
List<RevGradient> gradients; if (hasRevGradients(inst)) @@ -898,9 +928,9 @@ struct DiffTransposePass // if (isInstUsedOutsideParentBlock(inst) && !as<IRLoad>(inst)) { - auto accVar = getOrCreateAccumulatorVar(inst); + auto accGradient = extractAccumulatorVarGradient(builder, inst); gradients.add( - RevGradient(inst, builder->emitLoad(accVar), nullptr)); + RevGradient(inst, accGradient, nullptr)); } // Emit the aggregate of all the gradients here. @@ -2399,8 +2429,6 @@ struct DiffTransposePass Dictionary<IRInst*, IRVar*> revAccumulatorVarMap; - Dictionary<IRInst*, IRInst*>* primalsMap; - List<IRInst*> usedPtrs; Dictionary<IRBlock*, IRBlock*> revBlockMap; @@ -2412,6 +2440,8 @@ struct DiffTransposePass List<PendingBlockTerminatorEntry> pendingBlocks; Dictionary<IRBlock*, List<IRInst*>> phiGradsMap; + + Dictionary<IRBlock*, IRBlock*> initializerBlockMap; }; diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 057ff53c4..1a85ea6a4 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -232,6 +232,24 @@ struct DiffUnzipPass // lowerIndexedRegions(); + // Copy regions from fwd-block to their split blocks + // to make it easier to do lookups. + // + { + List<IRBlock*> workList; + for (auto blockRegionPair : indexRegionMap) + { + IRBlock* block = blockRegionPair.Key; + workList.add(block); + } + + for (auto block : workList) + { + indexRegionMap[as<IRBlock>(primalMap[block])] = (IndexedRegion*)indexRegionMap[block]; + indexRegionMap[as<IRBlock>(diffMap[block])] = (IndexedRegion*)indexRegionMap[block]; + } + } + // Process intermediate insts in indexed blocks // into array loads/stores. // @@ -262,19 +280,44 @@ struct DiffUnzipPass IRBlock* getUpdateBlock(IndexedRegion* region) { + // TODO: What if the 'continue' region has multiple + // blocks? + // We ideally want the _last_ block before control loops back. 
+ // + SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>( + region->continueBlock->getTerminator())->getTargetBlock() == region->firstBlock); + return region->continueBlock; } + + IRBlock* getFirstLoopBodyBlock(IndexedRegion* region) + { + // Grab the 'condition' block. + auto condBlock = region->firstBlock; + + SLANG_RELEASE_ASSERT(as<IRIfElse>(condBlock->getTerminator())); + + return as<IRIfElse>(condBlock->getTerminator())->getTrueBlock(); + } void tryInferMaxIndex(IndexedRegion* region) { if (region->status != IndexedRegion::CountStatus::Unresolved) return; + + auto loop = as<IRLoop>(region->initBlock->getTerminator()); - // We're going to fix this at a some random number - // for now, and then add some basic inference + user-defined decoration - // - region->maxIters = 5; - region->status = IndexedRegion::CountStatus::Static; + if (auto maxItersDecoration = loop->findDecoration<IRLoopMaxItersDecoration>()) + { + region->maxIters = (Count) maxItersDecoration->getMaxIters(); + region->status = IndexedRegion::CountStatus::Static; + } + + if (region->status == IndexedRegion::CountStatus::Unresolved) + { + SLANG_UNEXPECTED("Could not resolve max iters \ + for loop appearing in reverse-mode"); + } } // Make a primal value *available* to the differential block. @@ -297,22 +340,49 @@ struct DiffUnzipPass for (auto region : indexRegions) { - IRBlock* initializerBlock = getInitializerBlock(region); + //IRBlock* initializerBlock = getInitializerBlock(region); + IRBlock* breakBlock = region->breakBlock; // Grab first primal block. 
- auto firstPrimalBlock = primalMap[region->breakBlock->getParent()->getFirstBlock()->getNextBlock()]; + IRBlock* firstPrimalBlock = as<IRBlock>(primalMap[region->breakBlock->getParent()->getFirstBlock()->getNextBlock()]); // Make variable in the top-most block (so it's visible to diff blocks) - builder.setInsertInto(firstPrimalBlock); - region->primalCountVar = builder.emitVar(builder.getUIntType()); - - // Make another variable in the diff block initialized to the - // final value of the primal counter. + builder.setInsertBefore(firstPrimalBlock->getTerminator()); + region->primalCountVar = builder.emitVar(builder.getIntType()); + builder.emitStore( + region->primalCountVar, + builder.getIntValue(builder.getIntType(), 0)); + + // NOTE: This is a hacky shortcut we're taking here. + // Technically the unzip pass should not affect the + // correctness (it must still compute the proper fwd-mode derivative) + // However, we're currently making the loop counter go backwards to + // make it easier on the transposition pass, so the output from + // the unzip pass is neither fwd-mode or rev-mode until the transposition + // step is complete. + // + // TODO: Ideally this needs to be replaced with a small inversion step + // within the transposition pass. + // + // Emit the diff counter into the diff *break* block ( + // which we're praying turns into the reverse initializer block) + // initialized to the final value of the primal counter. 
// - builder.setInsertInto(diffMap[initializerBlock]); - auto primalCounterValue = builder.emitLoad(region->primalCountVar); - region->diffCountVar = builder.emitVar(builder.getUIntType()); - builder.emitStore(region->diffCountVar, primalCounterValue); + builder.setInsertBefore(as<IRBlock>(diffMap[breakBlock])->getTerminator()); + //auto primalCounterValue = builder.emitLoad(region->primalCountVar); + auto primalCounterCurrValue = builder.emitLoad(region->primalCountVar); + auto primalCounterLastValue = builder.emitSub( + primalCounterCurrValue->getDataType(), + primalCounterCurrValue, + builder.getIntValue(builder.getIntType(), 1)); + + region->diffCountVar = builder.emitVar(builder.getIntType()); + auto diffCountInit = builder.emitStore(region->diffCountVar, primalCounterLastValue); + + builder.addLoopCounterDecoration(diffCountInit); + builder.addLoopCounterDecoration(region->diffCountVar); + builder.addLoopCounterDecoration(primalCounterCurrValue); + builder.addLoopCounterDecoration(primalCounterLastValue); IRBlock* updateBlock = getUpdateBlock(region); @@ -324,9 +394,9 @@ struct DiffUnzipPass auto counterVal = builder.emitLoad(region->primalCountVar); auto incCounterVal = builder.emitAdd( - builder.getUIntType(), + builder.getIntType(), counterVal, - builder.getIntValue(builder.getUIntType(), 1)); + builder.getIntValue(builder.getIntType(), 1)); auto incStore = builder.emitStore(region->primalCountVar, incCounterVal); @@ -336,25 +406,16 @@ struct DiffUnzipPass } { - // NOTE: This is a hacky shortcut we're taking here. - // Technically the unzip pass should not affect the - // correctness (it must still compute the proper fwd-mode derivative) - // However, we're currently making the loop counter go backwards to - // make it easier on the transposition pass, so the output from - // the unzip pass is neither fwd-mode or rev-mode until the transposition - // step is complete. 
- // - // TODO: Ideally this needs to be replaced with a small inversion step - // within the transposition pass. - // + IRBlock* firstLoopBlock = getFirstLoopBodyBlock(region); + auto diffFirstLoopBlock = as<IRBlock>(diffMap[firstLoopBlock]); - builder.setInsertBefore(as<IRBlock>(diffMap[updateBlock])->getTerminator()); + builder.setInsertBefore(diffFirstLoopBlock->getTerminator()); auto counterVal = builder.emitLoad(region->diffCountVar); auto decCounterVal = builder.emitSub( - builder.getUIntType(), + builder.getIntType(), counterVal, - builder.getIntValue(builder.getUIntType(), 0)); + builder.getIntValue(builder.getIntType(), 1)); auto decStore = builder.emitStore(region->diffCountVar, decCounterVal); @@ -363,6 +424,27 @@ struct DiffUnzipPass builder.addLoopCounterDecoration(counterVal); builder.addLoopCounterDecoration(decCounterVal); builder.addLoopCounterDecoration(decStore); + + // TODO: + // This is another hack here to avoid the counter from going negative + // (since they are not valid indices) + // + IRBlock* diffCondBlock = as<IRBlock>(diffMap[region->firstBlock]); + + builder.setInsertBefore(diffCondBlock->getTerminator()); + IRInst* diffCounterVal = builder.emitLoad(region->diffCountVar); + IRInst* diffCounterCmp = builder.emitIntrinsicInst( + builder.getBoolType(), + kIROp_Geq, + 2, + List<IRInst*>( + diffCounterVal, + builder.getIntValue(builder.getIntType(), 0)).getBuffer()); + + as<IRIfElse>(diffCondBlock->getTerminator())->condition.set(diffCounterCmp); + + builder.addLoopCounterDecoration(diffCounterVal); + builder.addLoopCounterDecoration(diffCounterCmp); } } @@ -394,6 +476,7 @@ struct DiffUnzipPass for (; region; region = region->parent) regions.add(region); } + for (auto inst : primalInsts) { @@ -407,6 +490,7 @@ struct DiffUnzipPass if (isDifferentialInst(useBlock)) { shouldStore = true; + break; } } @@ -439,52 +523,111 @@ struct DiffUnzipPass auto storageVar = builder.emitVar(arrayType); // 3. 
Store current value into the array and replace uses with a load. + // TODO: If an index is missing, use the 'last' value of the primal index. { builder.setInsertAfter(inst); IRInst* storeAddr = storageVar; - IRType* currType = storageVar->getDataType(); + IRType* currType = as<IRPtrTypeBase>(storageVar->getDataType())->getValueType(); for (auto region : regions) { currType = as<IRArrayType>(currType)->getElementType(); storeAddr = builder.emitElementAddress( - currType, + builder.getPtrType(currType), storeAddr, - region->primalCountVar); + builder.emitLoad(region->primalCountVar)); } builder.emitStore(storeAddr, inst); } // 4. Replace uses in differential blocks with loads from the array. + List<IRInst*> instsToTag; { + List<IRUse*> diffUses; for (auto use = inst->firstUse; use; use = use->nextUse) + { + if (as<IRDecoration>(use->getUser())) + continue; + + IRBlock* useBlock = as<IRBlock>(use->getUser()->getParent()); + if (useBlock && isDifferentialInst(useBlock)) + diffUses.add(use); + } + + for (auto use : diffUses) { IRBlock* useBlock = as<IRBlock>(use->getUser()->getParent()); + builder.setInsertBefore(use->getUser()); + + IRInst* loadAddr = storageVar; + IRType* currType = as<IRPtrTypeBase>(storageVar->getDataType())->getValueType(); - if (isDifferentialInst(useBlock)) + // Enumerate use block regions. + // TODO: Probably a good idea to do this ahead of time for + // all blocks. + // + List<IndexedRegion*> useBlockRegions; { - builder.setInsertBefore(use->getUser()); + IndexedRegion* region = indexRegionMap.ContainsKey(useBlock) ? 
+ (IndexedRegion*)indexRegionMap[useBlock] : nullptr; + for (; region; region = region->parent) + useBlockRegions.add(region); + } - IRInst* loadAddr = storageVar; - IRType* currType = storageVar->getDataType(); + for (auto region : regions) + { + currType = as<IRArrayType>(currType)->getElementType(); + if (useBlockRegions.contains(region)) + { + // If the use-block is under the same region, use the + // differential counter variable + // + auto diffCounterCurrValue = builder.emitLoad(region->diffCountVar); + instsToTag.add(diffCounterCurrValue); - for (auto region : regions) + loadAddr = builder.emitElementAddress( + builder.getPtrType(currType), + loadAddr, + diffCounterCurrValue); + } + else { - currType = as<IRArrayType>(currType)->getElementType(); + // If the use-block is outside this region, use the + // last available value (by indexing with primal counter minus 1) + // + auto primalCounterCurrValue = builder.emitLoad(region->primalCountVar); + auto primalCounterLastValue = builder.emitSub( + primalCounterCurrValue->getDataType(), + primalCounterCurrValue, + builder.getIntValue(builder.getIntType(), 1)); + + instsToTag.add(primalCounterCurrValue); + instsToTag.add(primalCounterLastValue); loadAddr = builder.emitElementAddress( - currType, + builder.getPtrType(currType), loadAddr, - region->diffCountVar); + primalCounterLastValue); } - use->set(builder.emitLoad(loadAddr)); + instsToTag.add(loadAddr); } + + auto loadedValue = builder.emitLoad(loadAddr); + instsToTag.add(loadedValue); + + use->set(loadedValue); } } + + // TODO: Loop-counter is not really the right decoration.. + // replace with primal-inst when it's ready. + // + for (auto instToTag : instsToTag) + builder.addLoopCounterDecoration(instToTag); } } @@ -710,7 +853,11 @@ struct DiffUnzipPass // Check that we have an unambiguous 'first' differential block. 
SLANG_ASSERT(firstDiffBlock); + auto primalBranch = primalBuilder->emitBranch(firstDiffBlock); + primalBuilder->addBackwardDerivativePrimalReturnDecoration( + primalBranch, lookupPrimalInst(mixedReturn->getVal())); + auto pairVal = diffBuilder->emitMakeDifferentialPair( pairType, lookupPrimalInst(mixedReturn->getVal()), @@ -726,6 +873,9 @@ struct DiffUnzipPass { // If return value is not differentiable, just turn it into a trivial branch. auto primalBranch = primalBuilder->emitBranch(firstDiffBlock); + primalBuilder->addBackwardDerivativePrimalReturnDecoration( + primalBranch, primalBuilder->getVoidValue()); + auto returnInst = diffBuilder->emitReturn(); diffBuilder->markInstAsDifferential(returnInst, nullptr); return InstPair(primalBranch, returnInst); @@ -903,15 +1053,38 @@ struct DiffUnzipPass // Push a new index. addNewIndex(mixedLoop); - return InstPair( - primalBuilder->emitLoop( - as<IRBlock>(primalMap[nextBlock]), - as<IRBlock>(primalMap[breakBlock]), - as<IRBlock>(primalMap[continueBlock])), - diffBuilder->emitLoop( - as<IRBlock>(diffMap[nextBlock]), - as<IRBlock>(diffMap[breakBlock]), - as<IRBlock>(diffMap[continueBlock]))); + // Split args. 
+ List<IRInst*> primalArgs; + List<IRInst*> diffArgs; + for (UIndex ii = 0; ii < mixedLoop->getArgCount(); ii++) + { + if (isDifferentialInst(mixedLoop->getArg(ii))) + diffArgs.add(mixedLoop->getArg(ii)); + else + primalArgs.add(mixedLoop->getArg(ii)); + } + + auto primalLoop = primalBuilder->emitLoop( + as<IRBlock>(primalMap[nextBlock]), + as<IRBlock>(primalMap[breakBlock]), + as<IRBlock>(primalMap[continueBlock]), + primalArgs.getCount(), + primalArgs.getBuffer()); + + auto diffLoop = diffBuilder->emitLoop( + as<IRBlock>(diffMap[nextBlock]), + as<IRBlock>(diffMap[breakBlock]), + as<IRBlock>(diffMap[continueBlock]), + diffArgs.getCount(), + diffArgs.getBuffer()); + + if (auto maxItersDecoration = mixedLoop->findDecoration<IRLoopMaxItersDecoration>()) + { + primalBuilder->addLoopMaxItersDecoration(primalLoop, maxItersDecoration->getMaxIters()); + diffBuilder->addLoopMaxItersDecoration(diffLoop, maxItersDecoration->getMaxIters()); + } + + return InstPair(primalLoop, diffLoop); } InstPair splitControlFlow(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRInst* branchInst) diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index f38bdfdbd..2ce5a48f7 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -460,6 +460,7 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent) case kIROp_ForwardDerivativeDecoration: case kIROp_DerivativeMemberDecoration: case kIROp_DifferentiableTypeDictionaryDecoration: + case kIROp_PrimalInstDecoration: case kIROp_DifferentialInstDecoration: case kIROp_MixedDifferentialInstDecoration: case kIROp_BackwardDerivativeDecoration: diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 26a92a17a..e627c575d 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -598,6 +598,7 @@ INST(GetOptiXSbtDataPtr, getOptiXSbtDataPointer, 0, 0) INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) 
INST(LayoutDecoration, layout, 1, 0) INST(LoopControlDecoration, loopControl, 1, 0) + INST(LoopMaxItersDecoration, loopMaxIters, 1, 0) INST(IntrinsicOpDecoration, intrinsicOp, 1, 0) /* TargetSpecificDecoration */ INST(TargetDecoration, target, 1, 0) @@ -767,13 +768,19 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(LoopCounterDecoration, loopCounterDecoration, 0, 0) + /* Auto-diff inst decorations */ + /// Used by the auto-diff pass to mark insts that compute + /// a primal value. + INST(PrimalInstDecoration, primalInstDecoration, 0, 0) + /// Used by the auto-diff pass to mark insts that compute /// a differential value. - INST(DifferentialInstDecoration, diffInstDecoration, 1, 0) + INST(DifferentialInstDecoration, diffInstDecoration, 1, 0) /// Used by the auto-diff pass to mark insts that compute /// BOTH a differential and a primal value. - INST(MixedDifferentialInstDecoration, mixedDiffInstDecoration, 1, 0) + INST(MixedDifferentialInstDecoration, mixedDiffInstDecoration, 1, 0) + INST_RANGE(AutodiffInstDecoration, PrimalInstDecoration, MixedDifferentialInstDecoration) /// Used by the auto-diff pass to mark insts whose result is stored /// in an intermediary struct for reuse in backward propagation phase. 
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index fad20e900..2453b56a7 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -66,6 +66,14 @@ struct IRLoopControlDecoration : IRDecoration } }; +struct IRLoopMaxItersDecoration : IRDecoration +{ + enum { kOp = kIROp_LoopMaxItersDecoration }; + IR_LEAF_ISA(LoopMaxItersDecoration) + + IRConstant* getMaxItersInst() { return cast<IRConstant>(getOperand(0)); } + IRIntegerValue getMaxIters() { return as<IRIntLit>(getOperand(0))->getValue(); } +}; struct IRTargetSpecificDecoration : IRDecoration { @@ -672,7 +680,12 @@ struct IRLoopCounterDecoration : IRDecoration IR_LEAF_ISA(LoopCounterDecoration) }; -struct IRDifferentialInstDecoration : IRDecoration +struct IRAutodiffInstDecoration : IRDecoration +{ + IR_PARENT_ISA(AutodiffInstDecoration) +}; + +struct IRDifferentialInstDecoration : IRAutodiffInstDecoration { enum { @@ -686,41 +699,52 @@ struct IRDifferentialInstDecoration : IRDecoration IRInst* getPrimalInst() { return as<IRInst>(getOperand(1)); } }; -struct IRPrimalValueStructKeyDecoration : IRDecoration +struct IRPrimalInstDecoration : IRAutodiffInstDecoration { enum { - kOp = kIROp_PrimalValueStructKeyDecoration + kOp = kIROp_PrimalInstDecoration }; - IR_LEAF_ISA(PrimalValueStructKeyDecoration) + IR_LEAF_ISA(PrimalInstDecoration) +}; - IRStructKey* getStructKey() { return as<IRStructKey>(getOperand(0)); } + +struct IRMixedDifferentialInstDecoration : IRAutodiffInstDecoration +{ + enum + { + kOp = kIROp_MixedDifferentialInstDecoration + }; + + IRUse pairType; + IR_LEAF_ISA(MixedDifferentialInstDecoration) + + IRType* getPairType() { return as<IRType>(getOperand(0)); } }; -struct IRPrimalElementTypeDecoration : IRDecoration +struct IRPrimalValueStructKeyDecoration : IRDecoration { enum { - kOp = kIROp_PrimalElementTypeDecoration + kOp = kIROp_PrimalValueStructKeyDecoration }; - IR_LEAF_ISA(PrimalElementTypeDecoration) + IR_LEAF_ISA(PrimalValueStructKeyDecoration) 
- IRInst* getPrimalElementType() { return getOperand(0); } + IRStructKey* getStructKey() { return as<IRStructKey>(getOperand(0)); } }; -struct IRMixedDifferentialInstDecoration : IRDecoration +struct IRPrimalElementTypeDecoration : IRDecoration { enum { - kOp = kIROp_MixedDifferentialInstDecoration + kOp = kIROp_PrimalElementTypeDecoration }; - IRUse pairType; - IR_LEAF_ISA(MixedDifferentialInstDecoration) + IR_LEAF_ISA(PrimalElementTypeDecoration) - IRType* getPairType() { return as<IRType>(getOperand(0)); } + IRInst* getPrimalElementType() { return getOperand(0); } }; struct IRBackwardDifferentiableDecoration : IRDecoration @@ -3519,6 +3543,11 @@ public: addDecoration(value, kIROp_LoopControlDecoration, getIntValue(getIntType(), IRIntegerValue(mode))); } + void addLoopMaxItersDecoration(IRInst* value, IntegerLiteralValue iters) + { + addDecoration(value, kIROp_LoopMaxItersDecoration, getIntValue(getIntType(), iters)); + } + void addSemanticDecoration(IRInst* value, UnownedStringSlice const& text, int index = 0) { addDecoration(value, kIROp_SemanticDecoration, getStringValue(text), getIntValue(getIntType(), index)); @@ -3651,6 +3680,11 @@ public: addDecoration(value, kIROp_LoopCounterDecoration); } + void markInstAsPrimal(IRInst* value) + { + addDecoration(value, kIROp_PrimalInstDecoration); + } + void markInstAsDifferential(IRInst* value) { addDecoration(value, kIROp_DifferentialInstDecoration, nullptr); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 8377246fb..74f06557d 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -4845,6 +4845,10 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> { getBuilder()->addLoopControlDecoration(inst, kIRLoopControl_Loop); } + else if( auto maxItersAttr = stmt->findModifier<MaxItersAttribute>() ) + { + getBuilder()->addLoopMaxItersDecoration(inst, maxItersAttr->value); + } // TODO: handle other cases here } diff --git 
a/tests/autodiff/reverse-loop.slang b/tests/autodiff/reverse-loop.slang index 46d707548..f6e951eab 100644 --- a/tests/autodiff/reverse-loop.slang +++ b/tests/autodiff/reverse-loop.slang @@ -1,4 +1,3 @@ -//TEST_IGNORE_FILE: //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type //TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj @@ -14,6 +13,7 @@ float test_simple_loop(float y) { float t = y; + [MaxIters(5)] for (int i = 0; i < 3; i++) { t = t * t; @@ -29,13 +29,13 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) dpfloat dpa = dpfloat(1.0, 0.0); __bwd_diff(test_simple_loop)(dpa, 1.0f); - outputBuffer[0] = dpa.d; // Expect: 2.0 + outputBuffer[0] = dpa.d; // Expect: 8.0 } { dpfloat dpa = dpfloat(0.4, 0.0); - + __bwd_diff(test_simple_loop)(dpa, 1.0f); - outputBuffer[1] = dpa.d; // Expect: 1.0 + outputBuffer[1] = dpa.d; // Expect: 0.0131072 } } diff --git a/tests/autodiff/reverse-loop.slang.expected.txt b/tests/autodiff/reverse-loop.slang.expected.txt new file mode 100644 index 000000000..76b7cf779 --- /dev/null +++ b/tests/autodiff/reverse-loop.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +8.000000 +0.013107 +0.000000 +0.000000 +0.000000 |
