diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-04-28 16:22:12 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-04-28 13:22:12 -0700 |
| commit | 2492ec59fb52c15d1658ab32f473521b40664168 (patch) | |
| tree | 7569739846b16bda35577f6c33e37f46f99abfd4 /source | |
| parent | b07f4effda2f87ac9b3229e588121d224fd8cf52 (diff) | |
Fix handling of `[PreferRecompute]`. (#2855)
Co-authored-by: Yong He <yhe@nvidia.com>
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-doc-ast.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-primal-hoist.cpp | 131 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-peephole.cpp | 6 |
4 files changed, 42 insertions, 98 deletions
diff --git a/source/slang/slang-doc-ast.cpp b/source/slang/slang-doc-ast.cpp index 4f0d310bf..dfe3b321c 100644 --- a/source/slang/slang-doc-ast.cpp +++ b/source/slang/slang-doc-ast.cpp @@ -57,8 +57,6 @@ static void _addDeclRec(Decl* decl, List<Decl*>& outDecls) // If we don't have a loc, we have no way of locating documentation. if (decl->loc.isValid() || decl->nameAndLoc.loc.isValid()) { - if (as<AttributeDecl>(decl)) - printf("dd"); outDecls.add(decl); } diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index dd7472067..202660682 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -390,7 +390,7 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc( { if (auto var = as<IRVar>(result.instToRecompute)) { - IRUse* storeUse = findUniqueStoredVal(var); + IRUse* storeUse = findLatestUniqueWriteUse(var); if (storeUse) workList.add(storeUse); } @@ -1300,101 +1300,6 @@ void DefaultCheckpointPolicy::preparePolicy(IRGlobalValueWithCode* func) return; } -static bool doesInstHaveDiffUse(IRInst* inst) -{ - bool hasDiffUser = false; - - for (auto use = inst->firstUse; use; use = use->nextUse) - { - auto user = use->getUser(); - if (isDiffInst(user)) - { - // Ignore uses that is a return or MakeDiffPair - switch (user->getOp()) - { - case kIROp_Return: - continue; - case kIROp_MakeDifferentialPair: - if (!user->hasMoreThanOneUse() && user->firstUse && - user->firstUse->getUser()->getOp() == kIROp_Return) - continue; - break; - default: - break; - } - hasDiffUser = true; - break; - } - } - - return hasDiffUser; -} - -static bool doesInstHaveStore(IRInst* inst) -{ - SLANG_RELEASE_ASSERT(as<IRPtrTypeBase>(inst->getDataType())); - - for (auto use = inst->firstUse; use; use = use->nextUse) - { - if (as<IRStore>(use->getUser())) - return true; - - if (as<IRPtrTypeBase>(use->getUser()->getDataType())) - { - if (doesInstHaveStore(use->getUser())) - return true; - } - } - - return false; -} - -static bool isIntermediateContextType(IRType* type) -{ - switch (type->getOp()) - { - case kIROp_BackwardDiffIntermediateContextType: - return true; - case kIROp_PtrType: - return isIntermediateContextType(as<IRPtrTypeBase>(type)->getValueType()); - case kIROp_ArrayType: - return isIntermediateContextType(as<IRArrayType>(type)->getElementType()); - } - - return false; -} - -static bool shouldStoreVar(IRVar* var) -{ - // Always store intermediate context var. - if (const auto typeDecor = var->findDecoration<IRBackwardDerivativePrimalContextDecoration>()) - { - // If we are specializing a callee's intermediate context with types that can't be stored, - // we can't store the entire context. - if (auto spec = as<IRSpecialize>(as<IRPtrTypeBase>(var->getDataType())->getValueType())) - { - for (UInt i = 0; i < spec->getArgCount(); i++) - { - if (!canTypeBeStored(spec->getArg(i))) - return false; - } - } - return true; - } - - if (isIntermediateContextType(var->getDataType())) - { - return true; - } - - // For now the store policy is simple, we use two conditions: - // 1. Is the var used in a differential block and, - // 2. Does the var have a store - // - - return (doesInstHaveDiffUse(var) && doesInstHaveStore(var) && canTypeBeStored(as<IRPtrTypeBase>(var->getDataType())->getValueType())); -} - enum CheckpointPreference { None, @@ -1541,6 +1446,40 @@ static bool shouldStoreInst(IRInst* inst) return true; } +static bool shouldStoreVar(IRVar* var) +{ + if (const auto typeDecor = var->findDecoration<IRBackwardDerivativePrimalContextDecoration>()) + { + // If we are specializing a callee's intermediate context with types that can't be stored, + // we can't store the entire context. + if (auto spec = as<IRSpecialize>(as<IRPtrTypeBase>(var->getDataType())->getValueType())) + { + for (UInt i = 0; i < spec->getArgCount(); i++) + { + if (!canTypeBeStored(spec->getArg(i))) + return false; + } + } + } + + auto storeUse = findLatestUniqueWriteUse(var); + if (storeUse) + { + if (!canTypeBeStored(as<IRPtrTypeBase>(var->getDataType())->getValueType())) + return false; + if (auto callUser = as<IRCall>(storeUse->getUser())) + { + // If the var is being written to by a call, the decision + // of the var will be the same as the decision for the call. + return shouldStoreInst(callUser); + } + // Default behavior is to store if we can. + return true; + } + // If the var has never been written to, don't store it. + return false; +} + bool canRecompute(IRUse* use) { if (auto load = as<IRLoad>(use->get())) diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 62cb0b841..3b3224e2f 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -919,6 +919,7 @@ bool canTypeBeStored(IRInst* type) case kIROp_FloatType: case kIROp_VectorType: case kIROp_MatrixType: + case kIROp_BackwardDiffIntermediateContextType: return true; case kIROp_AttributedType: return canTypeBeStored(type->getOperand(0)); diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index b9ce2893a..f04012112 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -228,6 +228,12 @@ struct PeepholeContext : InstPassBase { return tryReplace(inst->getOperand(0)); } + else if (inst->getOperand(0) == inst->getOperand(1)) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + return tryReplace(builder.emitDefaultConstruct(inst->getDataType())); + } break; case kIROp_Mul: if (isOne(inst->getOperand(0))) |
