diff options
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/core.meta.slang | 3 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-check-modifier.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 67 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 38 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 94 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 279 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 11 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 62 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 4 |
13 files changed, 466 insertions, 118 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 31dd5ed29..533713016 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2823,6 +2823,9 @@ attribute_syntax [fastopt] : FastOptAttribute; __attributeTarget(LoopStmt) attribute_syntax [allow_uav_condition] : AllowUAVConditionAttribute; +__attributeTarget(LoopStmt) +attribute_syntax [MaxIters(count)] : MaxItersAttribute; + __attributeTarget(IfStmt) attribute_syntax [flatten] : FlattenAttribute; diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 666ca77ea..42b79ca4a 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -620,6 +620,14 @@ class UnrollAttribute : public Attribute IntegerLiteralValue getCount(); }; +// An `[maxiters(count)]` +class MaxItersAttribute : public Attribute +{ + SLANG_AST_CLASS(MaxItersAttribute) + + int32_t value = 0; +}; + class LoopAttribute : public Attribute { SLANG_AST_CLASS(LoopAttribute) diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index e73f04301..9f3e79978 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -507,6 +507,17 @@ namespace Slang // as 1 arg if nothing is specified) SLANG_ASSERT(attr->args.getCount() == 1); } + else if (auto maxItersAttrs = as<MaxItersAttribute>(attr)) + { + if (auto cint = checkConstantIntVal(attr->args[0])) + { + maxItersAttrs->value = (int32_t) cint->value; + } + else + { + getSink()->diagnose(attr, Diagnostics::notEnoughArguments, attr->args.getCount(), 1); + } + } else if (auto userDefAttr = as<UserDefinedAttribute>(attr)) { // check arguments against attribute parameters defined in attribClassDecl diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 04acad435..fca34f9a2 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -149,6 +149,7 @@ InstPair ForwardDiffTranscriber::transcribeBinaryArith(IRBuilder* builder, IRIns builder->markInstAsDifferential(diffSub, resultType); auto diffMul = builder->emitMul(resultType, primalRight, primalRight); + builder->markInstAsPrimal(diffMul); auto diffDiv = builder->emitDiv(diffType, diffSub, diffMul); builder->markInstAsDifferential(diffDiv, resultType); @@ -881,6 +882,29 @@ InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRI return InstPair(primalUpdateField, diffUpdateElement); } +List<IRInst*> ForwardDiffTranscriber::transcribePhiArgs(IRBuilder* builder, List<IRInst*> origPhiArgs) +{ + // Grab the differentials for any phi nodes. + List<IRInst*> newArgs; + for (auto origArg : origPhiArgs) + { + auto primalArg = lookupPrimalInst(builder, origArg); + newArgs.add(primalArg); + + if (differentiateType(builder, origArg->getDataType())) + { + auto diffArg = lookupDiffInst(origArg, nullptr); + if (diffArg) + newArgs.add(diffArg); + else + newArgs.add( + getDifferentialZeroOfType(builder, origArg->getDataType())); + } + } + + return newArgs; +} + InstPair ForwardDiffTranscriber::transcribeLoop(IRBuilder* builder, IRLoop* origLoop) { // The loop comes with three blocks.. we just need to transcribe each one @@ -902,13 +926,14 @@ InstPair ForwardDiffTranscriber::transcribeLoop(IRBuilder* builder, IRLoop* orig diffLoopOperands.add(diffTargetBlock); diffLoopOperands.add(diffBreakBlock); diffLoopOperands.add(diffContinueBlock); - - // If there are any other operands, use their primal versions. + + List<IRInst*> phiArgs; for (UIndex ii = diffLoopOperands.getCount(); ii < origLoop->getOperandCount(); ii++) - { - auto primalOperand = findOrTranscribePrimalInst(builder, origLoop->getOperand(ii)); - diffLoopOperands.add(primalOperand); - } + phiArgs.add(origLoop->getOperand(ii)); + + auto newPhiArgs = transcribePhiArgs(builder, phiArgs); + for (auto newArg : newPhiArgs) + diffLoopOperands.add(newArg); IRInst* diffLoop = builder->emitIntrinsicInst( nullptr, @@ -917,6 +942,9 @@ InstPair ForwardDiffTranscriber::transcribeLoop(IRBuilder* builder, IRLoop* orig diffLoopOperands.getBuffer()); builder->markInstAsMixedDifferential(diffLoop); + if (auto maxItersDecoration = origLoop->findDecoration<IRLoopMaxItersDecoration>()) + builder->addLoopMaxItersDecoration(diffLoop, maxItersDecoration->getMaxIters()); + return InstPair(diffLoop, diffLoop); } @@ -1211,6 +1239,28 @@ IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, I return diffFunc; } +void ForwardDiffTranscriber::checkAutodiffInstDecorations(IRFunc* fwdFunc) +{ + for (auto block = fwdFunc->getFirstBlock(); block; block = block->getNextBlock()) + { + for (auto inst = block->getFirstOrdinaryInst(); inst; inst = inst->getNextInst()) + { + // TODO: Special case, not sure why these insts show up + if (as<IRUndefined>(inst)) continue; + + List<IRDecoration*> decorations; + for (auto decoration : inst->getDecorations()) + { + if (as<IRAutodiffInstDecoration>(decoration)) + decorations.add(decoration); + } + + // Must have _exactly_ one autodiff tag. + SLANG_ASSERT(decorations.getCount() == 1); + } + } +} + // Transcribe a function definition. InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc) { @@ -1266,6 +1316,10 @@ InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* pr } } } + +#if _DEBUG + checkAutodiffInstDecorations(diffFunc); +#endif return InstPair(primalFunc, diffFunc); } @@ -1310,7 +1364,6 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_MakeMatrix: case kIROp_MakeMatrixFromScalar: case kIROp_MatrixReshape: - case kIROp_VectorReshape: case kIROp_IntCast: case kIROp_FloatCast: case kIROp_MakeVectorFromScalar: diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index 260b0a433..e80b25754 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -94,6 +94,10 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase // Transcribe a function without marking the result as a decoration. IRFunc* transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc); + List<IRInst*> transcribePhiArgs(IRBuilder* builder, List<IRInst*> origPhiArgs); + + void checkAutodiffInstDecorations(IRFunc* fwdFunc); + // Create an empty func to represent the transcribed func of `origFunc`. virtual InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) override; diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 702f9819a..20090ca42 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -765,7 +765,7 @@ namespace Slang // Transpose differential blocks from unzippedFwdDiffFunc into diffFunc (with dOutParameter) representing the // derivative of the return value. - DiffTransposePass::FuncTranspositionInfo info = { paramTransposeInfo.dOutParam, nullptr}; + DiffTransposePass::FuncTranspositionInfo info = { paramTransposeInfo.dOutParam }; diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, info); eliminateDeadCode(diffPropagateFunc); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 31a3072c0..10a734d65 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -901,6 +901,15 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst { mapPrimalInst(origInst, pair.primal); mapDifferentialInst(origInst, pair.differential); + + + if (pair.primal != pair.differential && + !pair.primal->findDecoration<IRAutodiffInstDecoration>() && + !as<IRConstant>(pair.primal)) + { + builder->markInstAsPrimal(pair.primal); + } + if (pair.differential) { switch (pair.differential->getOp()) @@ -920,16 +929,27 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice()); } - // Tag the differential inst using a decoration (if it doesn't have one) - if (!pair.differential->findDecoration<IRDifferentialInstDecoration>() && - !pair.differential->findDecoration<IRMixedDifferentialInstDecoration>() && - !as<IRConstant>(pair.differential)) + // Automatically tag the primal and differential results + // if they haven't already been handled by the + // code. + // + if (pair.primal != pair.differential) + { + if (!pair.differential->findDecoration<IRAutodiffInstDecoration>() + && !as<IRConstant>(pair.differential)) + { + auto primalType = as<IRType>(pair.primal->getDataType()); + builder->markInstAsDifferential(pair.differential, primalType); + } + } + else { - // TODO: If the type is a 'relevant' pair type, need to mark it as mixed differential - // instead. - // - auto primalType = as<IRType>(pair.primal->getDataType()); - builder->markInstAsDifferential(pair.differential, primalType); + if (!pair.primal->findDecoration<IRAutodiffInstDecoration>() + && !as<IRConstant>(pair.differential)) + { + auto mixedType = as<IRType>(pair.primal->getDataType()); + builder->markInstAsMixedDifferential(pair.primal, mixedType); + } } break; diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index d9b28ea3c..2953c6206 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -78,11 +78,6 @@ struct DiffTransposePass // of the *output* of the function. // IRInst* dOutInst; - - // Mapping between *primal* insts in the forward-mode function, and the - // reverse-mode function - // - Dictionary<IRInst*, IRInst*>* primalsMap; }; struct PendingBlockTerminatorEntry @@ -353,6 +348,13 @@ struct DiffTransposePass getPhiGrads(firstLoopBlock).getBuffer()); } + auto phiGrads = getPhiGrads(condBlock); + if (phiGrads.getCount() > 0) + { + revTrueBlock = insertPhiBlockBefore(revTrueBlock, phiGrads); + revFalseBlock = insertPhiBlockBefore(revFalseBlock, phiGrads); + } + // Emit condition into the new cond block. builder.setInsertInto(revCondBlock); builder.emitIfElse( @@ -533,8 +535,6 @@ struct DiffTransposePass // firstRevDiffBlockMap[revDiffFunc] = revBlockMap[terminalDiffBlocks[0]]; - IRInst* retVal = nullptr; - for (auto block : workList) { // Set dOutParameter as the transpose gradient for the return inst, if any. @@ -543,7 +543,6 @@ struct DiffTransposePass if (auto returnInst = as<IRReturn>(block->getTerminator())) { this->addRevGradientForFwdInst(returnInst, RevGradient(returnInst, transposeInfo.dOutInst, nullptr)); - retVal = returnInst->getVal(); } } @@ -572,6 +571,11 @@ struct DiffTransposePass auto terminalPrimalBlock = terminalPrimalBlocks[0]; auto firstRevBlock = as<IRBlock>(revBlockMap[terminalDiffBlocks[0]]); + auto returnDecoration = + terminalPrimalBlock->getTerminator()->findDecoration<IRBackwardDerivativePrimalReturnDecoration>(); + SLANG_ASSERT(returnDecoration); + auto retVal = returnDecoration->getBackwardDerivativePrimalReturnValue(); + terminalPrimalBlock->getTerminator()->removeAndDeallocate(); IRBuilder subBuilder(builder.getSharedBuilder()); @@ -582,15 +586,6 @@ struct DiffTransposePass auto branch = subBuilder.emitBranch(firstRevBlock); - if (!retVal || retVal->getOp() == kIROp_VoidLit) - { - retVal = subBuilder.getVoidValue(); - } - else - { - auto makePair = cast<IRMakeDifferentialPair>(retVal); - retVal = makePair->getPrimalValue(); - } subBuilder.addBackwardDerivativePrimalReturnDecoration(branch, retVal); } @@ -610,6 +605,25 @@ struct DiffTransposePass } } + IRInst* extractAccumulatorVarGradient(IRBuilder* builder, IRInst* fwdInst) + { + if (auto accVar = getOrCreateAccumulatorVar(fwdInst)) + { + auto gradValue = builder->emitLoad(accVar); + builder->emitStore( + accVar, + emitDZeroOfDiffInstType( + builder, + tryGetPrimalTypeFromDiffInst(fwdInst))); + + return gradValue; + } + else + { + return nullptr; + } + } + // Fetch or create a gradient accumulator var // corresponding to a inst. These are used to // accumulate gradients across blocks. @@ -688,6 +702,30 @@ struct DiffTransposePass } } + // Some special instructions simply need to be copied over. + // These do not deal with differentials. + // TODO: This will not work if there are any differential + // insts that rely on loop counter vars having a specific + // value. + // The solution is to have primal insts appearing in + // differential blocks be in their own special blocks that are + // ignored entirely, rather than dealing with them one inst + // at a time. + // + for (IRInst* child = fwdBlock->getFirstChild(); child;) + { + auto nextChild = child->getNextInst(); + + if (child->findDecoration<IRLoopCounterDecoration>()) + { + // Loop counter insts should not have any gradients. + SLANG_ASSERT(!hasRevGradients(child)); + child->insertAtEnd(revBlock); + } + + child = nextChild; + } + // Move pointer & reference insts to the top of the reverse-mode block. List<IRInst*> nonValueInsts; for (IRInst* child = fwdBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) @@ -719,6 +757,7 @@ struct DiffTransposePass if (as<IRDecoration>(child) || as<IRParam>(child)) continue; + transposeInst(&builder, child); } @@ -744,10 +783,10 @@ struct DiffTransposePass // if (isInstUsedOutsideParentBlock(param)) { - auto accVar = getOrCreateAccumulatorVar(param); + auto accGradient = extractAccumulatorVarGradient(&builder, param); addRevGradientForFwdInst( param, - RevGradient(param, builder.emitLoad(accVar), nullptr)); + RevGradient(param, accGradient, nullptr)); } if (hasRevGradients(param)) @@ -839,15 +878,6 @@ struct DiffTransposePass break; } - // Some special instructions simply need to be copied over. - // These do not deal with differentials. - // - if (inst->findDecoration<IRLoopCounterDecoration>()) - { - inst->insertAtEnd(builder->getBlock()); - return; - } - // Look for gradient entries for this inst. List<RevGradient> gradients; if (hasRevGradients(inst)) @@ -898,9 +928,9 @@ struct DiffTransposePass // if (isInstUsedOutsideParentBlock(inst) && !as<IRLoad>(inst)) { - auto accVar = getOrCreateAccumulatorVar(inst); + auto accGradient = extractAccumulatorVarGradient(builder, inst); gradients.add( - RevGradient(inst, builder->emitLoad(accVar), nullptr)); + RevGradient(inst, accGradient, nullptr)); } // Emit the aggregate of all the gradients here. @@ -2399,8 +2429,6 @@ struct DiffTransposePass Dictionary<IRInst*, IRVar*> revAccumulatorVarMap; - Dictionary<IRInst*, IRInst*>* primalsMap; - List<IRInst*> usedPtrs; Dictionary<IRBlock*, IRBlock*> revBlockMap; @@ -2412,6 +2440,8 @@ struct DiffTransposePass List<PendingBlockTerminatorEntry> pendingBlocks; Dictionary<IRBlock*, List<IRInst*>> phiGradsMap; + + Dictionary<IRBlock*, IRBlock*> initializerBlockMap; }; diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 057ff53c4..1a85ea6a4 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -232,6 +232,24 @@ struct DiffUnzipPass // lowerIndexedRegions(); + // Copy regions from fwd-block to their split blocks + // to make it easier to do lookups. + // + { + List<IRBlock*> workList; + for (auto blockRegionPair : indexRegionMap) + { + IRBlock* block = blockRegionPair.Key; + workList.add(block); + } + + for (auto block : workList) + { + indexRegionMap[as<IRBlock>(primalMap[block])] = (IndexedRegion*)indexRegionMap[block]; + indexRegionMap[as<IRBlock>(diffMap[block])] = (IndexedRegion*)indexRegionMap[block]; + } + } + // Process intermediate insts in indexed blocks // into array loads/stores. // @@ -262,19 +280,44 @@ struct DiffUnzipPass IRBlock* getUpdateBlock(IndexedRegion* region) { + // TODO: What if the 'continue' region has multiple + // blocks? + // We ideally want the _last_ block before control loops back. + // + SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>( + region->continueBlock->getTerminator())->getTargetBlock() == region->firstBlock); + return region->continueBlock; } + + IRBlock* getFirstLoopBodyBlock(IndexedRegion* region) + { + // Grab the 'condition' block. + auto condBlock = region->firstBlock; + + SLANG_RELEASE_ASSERT(as<IRIfElse>(condBlock->getTerminator())); + + return as<IRIfElse>(condBlock->getTerminator())->getTrueBlock(); + } void tryInferMaxIndex(IndexedRegion* region) { if (region->status != IndexedRegion::CountStatus::Unresolved) return; + + auto loop = as<IRLoop>(region->initBlock->getTerminator()); - // We're going to fix this at a some random number - // for now, and then add some basic inference + user-defined decoration - // - region->maxIters = 5; - region->status = IndexedRegion::CountStatus::Static; + if (auto maxItersDecoration = loop->findDecoration<IRLoopMaxItersDecoration>()) + { + region->maxIters = (Count) maxItersDecoration->getMaxIters(); + region->status = IndexedRegion::CountStatus::Static; + } + + if (region->status == IndexedRegion::CountStatus::Unresolved) + { + SLANG_UNEXPECTED("Could not resolve max iters \ + for loop appearing in reverse-mode"); + } } // Make a primal value *available* to the differential block. @@ -297,22 +340,49 @@ struct DiffUnzipPass for (auto region : indexRegions) { - IRBlock* initializerBlock = getInitializerBlock(region); + //IRBlock* initializerBlock = getInitializerBlock(region); + IRBlock* breakBlock = region->breakBlock; // Grab first primal block. - auto firstPrimalBlock = primalMap[region->breakBlock->getParent()->getFirstBlock()->getNextBlock()]; + IRBlock* firstPrimalBlock = as<IRBlock>(primalMap[region->breakBlock->getParent()->getFirstBlock()->getNextBlock()]); // Make variable in the top-most block (so it's visible to diff blocks) - builder.setInsertInto(firstPrimalBlock); - region->primalCountVar = builder.emitVar(builder.getUIntType()); - - // Make another variable in the diff block initialized to the - // final value of the primal counter. + builder.setInsertBefore(firstPrimalBlock->getTerminator()); + region->primalCountVar = builder.emitVar(builder.getIntType()); + builder.emitStore( + region->primalCountVar, + builder.getIntValue(builder.getIntType(), 0)); + + // NOTE: This is a hacky shortcut we're taking here. + // Technically the unzip pass should not affect the + // correctness (it must still compute the proper fwd-mode derivative) + // However, we're currently making the loop counter go backwards to + // make it easier on the transposition pass, so the output from + // the unzip pass is neither fwd-mode or rev-mode until the transposition + // step is complete. + // + // TODO: Ideally this needs to be replaced with a small inversion step + // within the transposition pass. + // + // Emit the diff counter into the diff *break* block ( + // which we're praying turns into the reverse initializer block) + // initialized to the final value of the primal counter. // - builder.setInsertInto(diffMap[initializerBlock]); - auto primalCounterValue = builder.emitLoad(region->primalCountVar); - region->diffCountVar = builder.emitVar(builder.getUIntType()); - builder.emitStore(region->diffCountVar, primalCounterValue); + builder.setInsertBefore(as<IRBlock>(diffMap[breakBlock])->getTerminator()); + //auto primalCounterValue = builder.emitLoad(region->primalCountVar); + auto primalCounterCurrValue = builder.emitLoad(region->primalCountVar); + auto primalCounterLastValue = builder.emitSub( + primalCounterCurrValue->getDataType(), + primalCounterCurrValue, + builder.getIntValue(builder.getIntType(), 1)); + + region->diffCountVar = builder.emitVar(builder.getIntType()); + auto diffCountInit = builder.emitStore(region->diffCountVar, primalCounterLastValue); + + builder.addLoopCounterDecoration(diffCountInit); + builder.addLoopCounterDecoration(region->diffCountVar); + builder.addLoopCounterDecoration(primalCounterCurrValue); + builder.addLoopCounterDecoration(primalCounterLastValue); IRBlock* updateBlock = getUpdateBlock(region); @@ -324,9 +394,9 @@ struct DiffUnzipPass auto counterVal = builder.emitLoad(region->primalCountVar); auto incCounterVal = builder.emitAdd( - builder.getUIntType(), + builder.getIntType(), counterVal, - builder.getIntValue(builder.getUIntType(), 1)); + builder.getIntValue(builder.getIntType(), 1)); auto incStore = builder.emitStore(region->primalCountVar, incCounterVal); @@ -336,25 +406,16 @@ struct DiffUnzipPass } { - // NOTE: This is a hacky shortcut we're taking here. - // Technically the unzip pass should not affect the - // correctness (it must still compute the proper fwd-mode derivative) - // However, we're currently making the loop counter go backwards to - // make it easier on the transposition pass, so the output from - // the unzip pass is neither fwd-mode or rev-mode until the transposition - // step is complete. - // - // TODO: Ideally this needs to be replaced with a small inversion step - // within the transposition pass. - // + IRBlock* firstLoopBlock = getFirstLoopBodyBlock(region); + auto diffFirstLoopBlock = as<IRBlock>(diffMap[firstLoopBlock]); - builder.setInsertBefore(as<IRBlock>(diffMap[updateBlock])->getTerminator()); + builder.setInsertBefore(diffFirstLoopBlock->getTerminator()); auto counterVal = builder.emitLoad(region->diffCountVar); auto decCounterVal = builder.emitSub( - builder.getUIntType(), + builder.getIntType(), counterVal, - builder.getIntValue(builder.getUIntType(), 0)); + builder.getIntValue(builder.getIntType(), 1)); auto decStore = builder.emitStore(region->diffCountVar, decCounterVal); @@ -363,6 +424,27 @@ struct DiffUnzipPass builder.addLoopCounterDecoration(counterVal); builder.addLoopCounterDecoration(decCounterVal); builder.addLoopCounterDecoration(decStore); + + // TODO: + // This is another hack here to avoid the counter from going negative + // (since they are not valid indices) + // + IRBlock* diffCondBlock = as<IRBlock>(diffMap[region->firstBlock]); + + builder.setInsertBefore(diffCondBlock->getTerminator()); + IRInst* diffCounterVal = builder.emitLoad(region->diffCountVar); + IRInst* diffCounterCmp = builder.emitIntrinsicInst( + builder.getBoolType(), + kIROp_Geq, + 2, + List<IRInst*>( + diffCounterVal, + builder.getIntValue(builder.getIntType(), 0)).getBuffer()); + + as<IRIfElse>(diffCondBlock->getTerminator())->condition.set(diffCounterCmp); + + builder.addLoopCounterDecoration(diffCounterVal); + builder.addLoopCounterDecoration(diffCounterCmp); } } @@ -394,6 +476,7 @@ struct DiffUnzipPass for (; region; region = region->parent) regions.add(region); } + for (auto inst : primalInsts) { @@ -407,6 +490,7 @@ struct DiffUnzipPass if (isDifferentialInst(useBlock)) { shouldStore = true; + break; } } @@ -439,52 +523,111 @@ struct DiffUnzipPass auto storageVar = builder.emitVar(arrayType); // 3. Store current value into the array and replace uses with a load. + // TODO: If an index is missing, use the 'last' value of the primal index. { builder.setInsertAfter(inst); IRInst* storeAddr = storageVar; - IRType* currType = storageVar->getDataType(); + IRType* currType = as<IRPtrTypeBase>(storageVar->getDataType())->getValueType(); for (auto region : regions) { currType = as<IRArrayType>(currType)->getElementType(); storeAddr = builder.emitElementAddress( - currType, + builder.getPtrType(currType), storeAddr, - region->primalCountVar); + builder.emitLoad(region->primalCountVar)); } builder.emitStore(storeAddr, inst); } // 4. Replace uses in differential blocks with loads from the array. + List<IRInst*> instsToTag; { + List<IRUse*> diffUses; for (auto use = inst->firstUse; use; use = use->nextUse) + { + if (as<IRDecoration>(use->getUser())) + continue; + + IRBlock* useBlock = as<IRBlock>(use->getUser()->getParent()); + if (useBlock && isDifferentialInst(useBlock)) + diffUses.add(use); + } + + for (auto use : diffUses) { IRBlock* useBlock = as<IRBlock>(use->getUser()->getParent()); + builder.setInsertBefore(use->getUser()); + + IRInst* loadAddr = storageVar; + IRType* currType = as<IRPtrTypeBase>(storageVar->getDataType())->getValueType(); - if (isDifferentialInst(useBlock)) + // Enumerate use block regions. + // TODO: Probably a good idea to do this ahead of time for + // all blocks. + // + List<IndexedRegion*> useBlockRegions; { - builder.setInsertBefore(use->getUser()); + IndexedRegion* region = indexRegionMap.ContainsKey(useBlock) ? + (IndexedRegion*)indexRegionMap[useBlock] : nullptr; + for (; region; region = region->parent) + useBlockRegions.add(region); + } - IRInst* loadAddr = storageVar; - IRType* currType = storageVar->getDataType(); + for (auto region : regions) + { + currType = as<IRArrayType>(currType)->getElementType(); + if (useBlockRegions.contains(region)) + { + // If the use-block is under the same region, use the + // differential counter variable + // + auto diffCounterCurrValue = builder.emitLoad(region->diffCountVar); + instsToTag.add(diffCounterCurrValue); - for (auto region : regions) + loadAddr = builder.emitElementAddress( + builder.getPtrType(currType), + loadAddr, + diffCounterCurrValue); + } + else { - currType = as<IRArrayType>(currType)->getElementType(); + // If the use-block is outside this region, use the + // last available value (by indexing with primal counter minus 1) + // + auto primalCounterCurrValue = builder.emitLoad(region->primalCountVar); + auto primalCounterLastValue = builder.emitSub( + primalCounterCurrValue->getDataType(), + primalCounterCurrValue, + builder.getIntValue(builder.getIntType(), 1)); + + instsToTag.add(primalCounterCurrValue); + instsToTag.add(primalCounterLastValue); loadAddr = builder.emitElementAddress( - currType, + builder.getPtrType(currType), loadAddr, - region->diffCountVar); + primalCounterLastValue); } - use->set(builder.emitLoad(loadAddr)); + instsToTag.add(loadAddr); } + + auto loadedValue = builder.emitLoad(loadAddr); + instsToTag.add(loadedValue); + + use->set(loadedValue); } } + + // TODO: Loop-counter is not really the right decoration.. + // replace with primal-inst when it's ready. + // + for (auto instToTag : instsToTag) + builder.addLoopCounterDecoration(instToTag); } } @@ -710,7 +853,11 @@ struct DiffUnzipPass // Check that we have an unambiguous 'first' differential block. SLANG_ASSERT(firstDiffBlock); + auto primalBranch = primalBuilder->emitBranch(firstDiffBlock); + primalBuilder->addBackwardDerivativePrimalReturnDecoration( + primalBranch, lookupPrimalInst(mixedReturn->getVal())); + auto pairVal = diffBuilder->emitMakeDifferentialPair( pairType, lookupPrimalInst(mixedReturn->getVal()), @@ -726,6 +873,9 @@ struct DiffUnzipPass { // If return value is not differentiable, just turn it into a trivial branch. auto primalBranch = primalBuilder->emitBranch(firstDiffBlock); + primalBuilder->addBackwardDerivativePrimalReturnDecoration( + primalBranch, primalBuilder->getVoidValue()); + auto returnInst = diffBuilder->emitReturn(); diffBuilder->markInstAsDifferential(returnInst, nullptr); return InstPair(primalBranch, returnInst); @@ -903,15 +1053,38 @@ struct DiffUnzipPass // Push a new index. addNewIndex(mixedLoop); - return InstPair( - primalBuilder->emitLoop( - as<IRBlock>(primalMap[nextBlock]), - as<IRBlock>(primalMap[breakBlock]), - as<IRBlock>(primalMap[continueBlock])), - diffBuilder->emitLoop( - as<IRBlock>(diffMap[nextBlock]), - as<IRBlock>(diffMap[breakBlock]), - as<IRBlock>(diffMap[continueBlock]))); + // Split args. + List<IRInst*> primalArgs; + List<IRInst*> diffArgs; + for (UIndex ii = 0; ii < mixedLoop->getArgCount(); ii++) + { + if (isDifferentialInst(mixedLoop->getArg(ii))) + diffArgs.add(mixedLoop->getArg(ii)); + else + primalArgs.add(mixedLoop->getArg(ii)); + } + + auto primalLoop = primalBuilder->emitLoop( + as<IRBlock>(primalMap[nextBlock]), + as<IRBlock>(primalMap[breakBlock]), + as<IRBlock>(primalMap[continueBlock]), + primalArgs.getCount(), + primalArgs.getBuffer()); + + auto diffLoop = diffBuilder->emitLoop( + as<IRBlock>(diffMap[nextBlock]), + as<IRBlock>(diffMap[breakBlock]), + as<IRBlock>(diffMap[continueBlock]), + diffArgs.getCount(), + diffArgs.getBuffer()); + + if (auto maxItersDecoration = mixedLoop->findDecoration<IRLoopMaxItersDecoration>()) + { + primalBuilder->addLoopMaxItersDecoration(primalLoop, maxItersDecoration->getMaxIters()); + diffBuilder->addLoopMaxItersDecoration(diffLoop, maxItersDecoration->getMaxIters()); + } + + return InstPair(primalLoop, diffLoop); } InstPair splitControlFlow(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRInst* branchInst) diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index f38bdfdbd..2ce5a48f7 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -460,6 +460,7 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent) case kIROp_ForwardDerivativeDecoration: case kIROp_DerivativeMemberDecoration: case kIROp_DifferentiableTypeDictionaryDecoration: + case kIROp_PrimalInstDecoration: case kIROp_DifferentialInstDecoration: case kIROp_MixedDifferentialInstDecoration: case kIROp_BackwardDerivativeDecoration: diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 26a92a17a..e627c575d 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -598,6 +598,7 @@ INST(GetOptiXSbtDataPtr, getOptiXSbtDataPointer, 0, 0) INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(LayoutDecoration, layout, 1, 0) INST(LoopControlDecoration, loopControl, 1, 0) + INST(LoopMaxItersDecoration, loopMaxIters, 1, 0) INST(IntrinsicOpDecoration, intrinsicOp, 1, 0) /* TargetSpecificDecoration */ INST(TargetDecoration, target, 1, 0) @@ -767,13 +768,19 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(LoopCounterDecoration, loopCounterDecoration, 0, 0) + /* Auto-diff inst decorations */ + /// Used by the auto-diff pass to mark insts that compute + /// a primal value. + INST(PrimalInstDecoration, primalInstDecoration, 0, 0) + /// Used by the auto-diff pass to mark insts that compute /// a differential value. - INST(DifferentialInstDecoration, diffInstDecoration, 1, 0) + INST(DifferentialInstDecoration, diffInstDecoration, 1, 0) /// Used by the auto-diff pass to mark insts that compute /// BOTH a differential and a primal value. - INST(MixedDifferentialInstDecoration, mixedDiffInstDecoration, 1, 0) + INST(MixedDifferentialInstDecoration, mixedDiffInstDecoration, 1, 0) + INST_RANGE(AutodiffInstDecoration, PrimalInstDecoration, MixedDifferentialInstDecoration) /// Used by the auto-diff pass to mark insts whose result is stored /// in an intermediary struct for reuse in backward propagation phase. diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index fad20e900..2453b56a7 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -66,6 +66,14 @@ struct IRLoopControlDecoration : IRDecoration } }; +struct IRLoopMaxItersDecoration : IRDecoration +{ + enum { kOp = kIROp_LoopMaxItersDecoration }; + IR_LEAF_ISA(LoopMaxItersDecoration) + + IRConstant* getMaxItersInst() { return cast<IRConstant>(getOperand(0)); } + IRIntegerValue getMaxIters() { return as<IRIntLit>(getOperand(0))->getValue(); } +}; struct IRTargetSpecificDecoration : IRDecoration { @@ -672,7 +680,12 @@ struct IRLoopCounterDecoration : IRDecoration IR_LEAF_ISA(LoopCounterDecoration) }; -struct IRDifferentialInstDecoration : IRDecoration +struct IRAutodiffInstDecoration : IRDecoration +{ + IR_PARENT_ISA(AutodiffInstDecoration) +}; + +struct IRDifferentialInstDecoration : IRAutodiffInstDecoration { enum { @@ -686,41 +699,52 @@ struct IRDifferentialInstDecoration : IRDecoration IRInst* getPrimalInst() { return as<IRInst>(getOperand(1)); } }; -struct IRPrimalValueStructKeyDecoration : IRDecoration +struct IRPrimalInstDecoration : IRAutodiffInstDecoration { enum { - kOp = kIROp_PrimalValueStructKeyDecoration + kOp = kIROp_PrimalInstDecoration }; - IR_LEAF_ISA(PrimalValueStructKeyDecoration) + IR_LEAF_ISA(PrimalInstDecoration) +}; - IRStructKey* getStructKey() { return as<IRStructKey>(getOperand(0)); } + +struct IRMixedDifferentialInstDecoration : IRAutodiffInstDecoration +{ + enum + { + kOp = kIROp_MixedDifferentialInstDecoration + }; + + IRUse pairType; + IR_LEAF_ISA(MixedDifferentialInstDecoration) + + IRType* getPairType() { return as<IRType>(getOperand(0)); } }; -struct IRPrimalElementTypeDecoration : IRDecoration +struct IRPrimalValueStructKeyDecoration : IRDecoration { enum { - kOp = kIROp_PrimalElementTypeDecoration + kOp = kIROp_PrimalValueStructKeyDecoration }; - IR_LEAF_ISA(PrimalElementTypeDecoration) + IR_LEAF_ISA(PrimalValueStructKeyDecoration) - IRInst* getPrimalElementType() { return getOperand(0); } + IRStructKey* getStructKey() { return as<IRStructKey>(getOperand(0)); } }; -struct IRMixedDifferentialInstDecoration : IRDecoration +struct IRPrimalElementTypeDecoration : IRDecoration { enum { - kOp = kIROp_MixedDifferentialInstDecoration + kOp = kIROp_PrimalElementTypeDecoration }; - IRUse pairType; - IR_LEAF_ISA(MixedDifferentialInstDecoration) + IR_LEAF_ISA(PrimalElementTypeDecoration) - IRType* getPairType() { return as<IRType>(getOperand(0)); } + IRInst* getPrimalElementType() { return getOperand(0); } }; struct IRBackwardDifferentiableDecoration : IRDecoration @@ -3519,6 +3543,11 @@ public: addDecoration(value, kIROp_LoopControlDecoration, getIntValue(getIntType(), IRIntegerValue(mode))); } + void addLoopMaxItersDecoration(IRInst* value, IntegerLiteralValue iters) + { + addDecoration(value, kIROp_LoopMaxItersDecoration, getIntValue(getIntType(), iters)); + } + void addSemanticDecoration(IRInst* value, UnownedStringSlice const& text, int index = 0) { addDecoration(value, kIROp_SemanticDecoration, getStringValue(text), getIntValue(getIntType(), index)); @@ -3651,6 +3680,11 @@ public: addDecoration(value, kIROp_LoopCounterDecoration); } + void markInstAsPrimal(IRInst* value) + { + addDecoration(value, kIROp_PrimalInstDecoration); + } + void markInstAsDifferential(IRInst* value) { addDecoration(value, kIROp_DifferentialInstDecoration, nullptr); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 8377246fb..74f06557d 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -4845,6 +4845,10 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> { getBuilder()->addLoopControlDecoration(inst, kIRLoopControl_Loop); } + else if( auto maxItersAttr = stmt->findModifier<MaxItersAttribute>() ) + { + getBuilder()->addLoopMaxItersDecoration(inst, maxItersAttr->value); + } // TODO: handle other cases here } |
