diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-02-22 22:22:26 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-22 19:22:26 -0800 |
| commit | e8c08e7ecb1124f115a1d1042277776193122b57 (patch) | |
| tree | 9c1d970c8be244aa4a32762e1de3338507d24444 /source | |
| parent | 6eb0b4dea4da1fc21767c86cc0837d0c8b68063b (diff) | |
Fixed hoisting of intermediate array & context vars (#2674)
Also added legalization for loops
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-cfg-norm.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.cpp | 113 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 143 |
5 files changed, 201 insertions, 66 deletions
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp index d64c6d1f6..9116f67e9 100644 --- a/source/slang/slang-ir-autodiff-cfg-norm.cpp +++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp @@ -409,6 +409,7 @@ struct CFGNormalizationPass // false -> atleast one break statement hit. // info.breakVar = builder.emitVar(builder.getBoolType()); + builder.addNameHintDecoration(info.breakVar, UnownedStringSlice("_bflag")); builder.emitStore(info.breakVar, builder.getBoolValue(true)); // If the loop is trivial (i.e. single iteration, with no diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h index 94bc1ef81..86a6f2846 100644 --- a/source/slang/slang-ir-autodiff-rev.h +++ b/source/slang/slang-ir-autodiff-rev.h @@ -32,8 +32,8 @@ struct ParameterBlockTransposeInfo // The value with which a primal specific parameter should be replaced in propagate func. OrderedDictionary<IRInst*, IRInst*> mapPrimalSpecificParamToReplacementInPropFunc; - // The insts added that is specific for propagate functions and should be removed + // from the future primal func. List<IRInst*> propagateFuncSpecificPrimalInsts; diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 95ad58586..a4c79d09a 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -1138,7 +1138,11 @@ struct DiffTransposePass IRInst* hoistPrimalInst(IRBuilder* revBuilder, IRInst* inst) { - SLANG_RELEASE_ASSERT(isPrimalInst(inst)); + if (as<IRBlock>(inst->getParent()) && + isDifferentialInst(as<IRBlock>(inst->getParent()))) + { + SLANG_RELEASE_ASSERT(isPrimalInst(inst)); + } // Are the operands of this primal inst also available in the reverse-mode context? // If not, move/load them. @@ -1379,7 +1383,7 @@ struct DiffTransposePass // In order to perform the call, we need a temporary var to store the DiffPair. auto pairType = as<IRPtrTypeBase>(arg->getDataType())->getValueType(); auto tempVar = builder->emitVar(pairType); - auto primalVal = builder->emitLoad(instPair->getPrimal()); + auto primalVal = builder->emitLoad(hoistPrimalInst(builder, instPair->getPrimal())); auto diffVal = builder->emitLoad(instPair->getDiff()); auto pairVal = builder->emitMakeDifferentialPair(pairType, primalVal, diffVal); builder->emitStore(tempVar, pairVal); diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 50c5c4ea6..096751836 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -130,6 +130,91 @@ struct ExtractPrimalFuncContext } } + bool doesInstHaveDiffUse(IRInst* inst) + { + bool hasDiffUser = false; + + for (auto use = inst->firstUse; use; use = use->nextUse) + { + auto user = use->getUser(); + if (isDiffInst(user)) + { + // Ignore uses that is a return or MakeDiffPair + switch (user->getOp()) + { + case kIROp_Return: + continue; + case kIROp_MakeDifferentialPair: + if (!user->hasMoreThanOneUse() && user->firstUse && + user->firstUse->getUser()->getOp() == kIROp_Return) + continue; + break; + default: + break; + } + hasDiffUser = true; + break; + } + } + + return hasDiffUser; + } + + bool doesInstHaveStore(IRInst* inst) + { + SLANG_RELEASE_ASSERT(as<IRPtrTypeBase>(inst->getDataType())); + + for (auto use = inst->firstUse; use; use = use->nextUse) + { + if (as<IRStore>(use->getUser())) + return true; + + if (as<IRPtrTypeBase>(use->getUser()->getDataType())) + { + if (doesInstHaveStore(use->getUser())) + return true; + } + } + + return false; + } + + bool isIntermediateContextType(IRType* type) + { + switch (type->getOp()) + { + case kIROp_BackwardDiffIntermediateContextType: + return true; + case kIROp_PtrType: + return isIntermediateContextType(as<IRPtrTypeBase>(type)->getValueType()); + case kIROp_ArrayType: + return isIntermediateContextType(as<IRArrayType>(type)->getElementType()); + } + + return false; + } + + bool shouldStoreVar(IRVar* var) + { + // Always store intermediate context var. + if (var->findDecoration<IRBackwardDerivativePrimalContextDecoration>()) + { + return true; + } + + if (isIntermediateContextType(var->getDataType())) + { + return true; + } + + // For now the store policy is simple, we use two conditions: + // 1. Is the var used in a differential block and, + // 2. Does the var have a store + // + + return (doesInstHaveDiffUse(var) && doesInstHaveStore(var)); + } + bool shouldStoreInst(IRInst* inst) { if (!inst->getDataType()) @@ -181,29 +266,7 @@ struct ExtractPrimalFuncContext } // Only store if the inst has differential inst user. - bool hasDiffUser = false; - for (auto use = inst->firstUse; use; use = use->nextUse) - { - auto user = use->getUser(); - if (isDiffInst(user)) - { - // Ignore uses that is a return or MakeDiffPair - switch (user->getOp()) - { - case kIROp_Return: - continue; - case kIROp_MakeDifferentialPair: - if (!user->hasMoreThanOneUse() && user->firstUse && - user->firstUse->getUser()->getOp() == kIROp_Return) - continue; - break; - default: - break; - } - hasDiffUser = true; - break; - } - } + bool hasDiffUser = doesInstHaveDiffUse(inst); if (!hasDiffUser) return false; @@ -303,8 +366,7 @@ struct ExtractPrimalFuncContext } else if (inst->getOp() == kIROp_Var) { - // Always store intermediate context var. - if (inst->findDecoration<IRBackwardDerivativePrimalContextDecoration>()) + if (shouldStoreVar(as<IRVar>(inst))) { auto field = addIntermediateContextField(cast<IRPtrTypeBase>(inst->getDataType())->getValueType(), outIntermediary); builder.setInsertBefore(inst); @@ -313,6 +375,7 @@ struct ExtractPrimalFuncContext inst->replaceUsesWith(fieldAddr); builder.addPrimalValueStructKeyDecoration(inst, field->getKey()); } + } } } diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index e2c84ce8b..2ebc330f0 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -414,6 +414,7 @@ struct DiffUnzipPass // Make variable in the top-most block (so it's visible to diff blocks) region->primalCountLastVar = builder.emitVar(builder.getIntType()); + builder.addNameHintDecoration(region->primalCountLastVar, UnownedStringSlice("_pc_last_var")); { IRBlock* primalInitBlock = as<IRBlock>(primalMap[region->initBlock]); @@ -432,6 +433,7 @@ struct DiffUnzipPass primalCondBlock, builder.getIntType(), phiCounterArgLoopEntryIndex); + builder.addNameHintDecoration(region->primalCountParam, UnownedStringSlice("_pc")); builder.addLoopCounterDecoration(region->primalCountParam); builder.markInstAsPrimal(region->primalCountParam); @@ -471,6 +473,7 @@ struct DiffUnzipPass diffCondBlock, builder.getIntType(), phiCounterArgLoopEntryIndex); + builder.addNameHintDecoration(region->diffCountParam, UnownedStringSlice("_dc")); builder.addLoopCounterDecoration(region->diffCountParam); builder.markInstAsPrimal(region->diffCountParam); @@ -535,33 +538,11 @@ struct DiffUnzipPass as<IRFuncType>(child->getDataType()) || as<IRTypeKind>(child->getDataType())) continue; - - // We also don't care about pointer types (only Loads) - if (auto ptrType = as<IRPtrTypeBase>(child->getDataType())) - { - // There's an exception to this, if the var is an intermediate context type - // variable since there won't be a load from this yet (the load will - // be inserted later during the transposition process) - // - if (as<IRBackwardDiffIntermediateContextType>(ptrType->getValueType())) - primalInsts.add(child); - - continue; - } primalInsts.add(child); } IRBuilder builder(autodiffContext->moduleInst->getModule()); - - // Build list of indices that this block is affected by. - List<IndexedRegion*> regions; - { - IndexedRegion* region = indexRegionMap[fwdBlock]; - for (; region; region = region->parent) - regions.add(region); - } - for (auto inst : primalInsts) { @@ -581,43 +562,115 @@ struct DiffUnzipPass if (!shouldStore) continue; - // 2. Emit an array to top-level to allocate space. - - builder.setInsertBefore(firstPrimalBlock->getTerminator()); + // 2. If we're dealing with a var, we need to locate the value that + // we actually need to store. We assume everything is SSA form + // so there must be a single IRStore on this var. + // + IRInst* valueToStore = nullptr; + IRBlock* valueBlock = nullptr; + IRType* valueType = nullptr; - IRType* arrayType = inst->getDataType(); bool isPtrType = false; + bool isIntermediateContext = false; - if (auto ptrType = as<IRPtrTypeBase>(arrayType)) + if (auto ptrValueType = as<IRPtrTypeBase>(inst->getDataType())) { - SLANG_RELEASE_ASSERT(as<IRBackwardDiffIntermediateContextType>(ptrType->getValueType())); - arrayType = ptrType->getValueType(); isPtrType = true; + + // Find value to store + for (auto use = inst->firstUse; use; use = use->nextUse) + { + if (auto storeInst = as<IRStore>(use->getUser())) + { + // Should not see more than one IRStore + SLANG_RELEASE_ASSERT(!valueToStore); + valueToStore = storeInst->getVal(); + + // Is this the right block to use to determine if the + // store can have multiple values based on the index? + // + valueBlock = as<IRBlock>(storeInst->getParent()); + } + } + + if (as<IRBackwardDiffIntermediateContextType>(ptrValueType->getValueType())) + { + isIntermediateContext = true; + + // TODO: This should be the parent block of the `call` associated + // with this context type. The var itself _could_ be in a different place. + // + valueBlock = as<IRBlock>(inst->getParent()); + } + + valueType = ptrValueType->getValueType(); + } + else + { + isPtrType = false; + valueToStore = inst; + valueBlock = as<IRBlock>(inst->getParent()); + valueType = inst->getDataType(); } + // What do we do for primal vars that are used in the diff block + // but do not have an IRStore on them? This can happen for 'out' + // primal variables. + // + if (!valueToStore && !isIntermediateContext) + { + // For now, we can ignore them since they are used as inputs + // to 'out' parameters. If their value is every actually used, + // we will see an IRLoad which will be hoisted accordingly. + // + continue; + } + + // Build list of indices that the value's block is affected by. + List<IndexedRegion*> regions; + { + IndexedRegion* region = indexRegionMap[valueBlock]; + for (; region; region = region->parent) + regions.add(region); + } + + // 3. Emit an array to top-level to allocate space. + + builder.setInsertBefore(firstPrimalBlock->getTerminator()); + + IRType* storageType = valueType; + for (auto region : regions) { SLANG_ASSERT(region->status == IndexedRegion::CountStatus::Static); SLANG_ASSERT(region->maxIters >= 0); - arrayType = builder.getArrayType( - arrayType, + storageType = builder.getArrayType( + storageType, builder.getIntValue( builder.getUIntType(), region->maxIters + 1)); } - // Reverse the list since the indices needs to be + // Reverse the list since the indices need to be // emitted in reverse order. // regions.reverse(); - auto storageVar = builder.emitVar(arrayType); + auto storageVar = builder.emitVar(storageType); + if (isIntermediateContext) + builder.addBackwardDerivativePrimalContextDecoration( + storageVar, + storageVar); - // 3. Store current value into the array and replace uses with a load. + // 4. Store current value into the array and replace uses with a load. // TODO: If an index is missing, use the 'last' value of the primal index. + { - setInsertAfterOrdinaryInst(&builder, inst); + if (!isIntermediateContext) + setInsertAfterOrdinaryInst(&builder, valueToStore); + else + setInsertAfterOrdinaryInst(&builder, inst); IRInst* storeAddr = storageVar; IRType* currType = as<IRPtrTypeBase>(storageVar->getDataType())->getValueType(); @@ -631,11 +684,25 @@ struct DiffUnzipPass storeAddr, region->primalCountParam); } - - builder.emitStore(storeAddr, inst); + + if (!isIntermediateContext) + builder.emitStore(storeAddr, valueToStore); + else + { + List<IRUse*> primalUses; + for (auto use = inst->firstUse; use; use = use->nextUse) + { + if (!isDifferentialInst(getBlock(use->getUser()))) + primalUses.add(use); + } + + for (auto use : primalUses) + use->set(storeAddr); + } } + - // 4. Replace uses in differential blocks with loads from the array. + // 5. Replace uses in differential blocks with loads from the array. List<IRInst*> instsToTag; { List<IRUse*> diffUses; |
