diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-03-15 22:26:58 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-15 19:26:58 -0700 |
| commit | 71efd949fa5276e2464416fcf237f8fd2c486281 (patch) | |
| tree | a5b24cd077f2ecc3f74d4dd4671c8260eb6e9b67 /source/slang/slang-ir-autodiff.cpp | |
| parent | 38e62199cc75ce34608491c8dd299eb330bde518 (diff) | |
AD: Primal-Hoisting Rework + Checkpoint Policy Framework (#2702)
Diffstat (limited to 'source/slang/slang-ir-autodiff.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 98 |
1 files changed, 97 insertions, 1 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 517b9e3ea..a3a7e4b77 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -595,6 +595,7 @@ bool canTypeBeStored(IRInst* type) case kIROp_FloatType: case kIROp_VectorType: case kIROp_MatrixType: + case kIROp_AttributedType: return true; default: return false; @@ -904,7 +905,7 @@ struct AutoDiffPass : public InstPassBase else { IntermediateContextTypeDifferentialInfo diffFieldTypeInfo; - diffTypes.TryGetValue(field->getDataType(), diffFieldTypeInfo); + diffTypes.TryGetValue(field->getFieldType(), diffFieldTypeInfo); diffFieldWitness = diffFieldTypeInfo.diffWitness; } if (diffFieldWitness) @@ -1429,4 +1430,99 @@ bool finalizeAutoDiffPass(IRModule* module) return false; } +IRBlock* getBlock(IRInst* inst) +{ + SLANG_RELEASE_ASSERT(inst); + + if (auto block = as<IRBlock>(inst)) + return block; + + return getBlock(inst->getParent()); +} + +IRInst* getInstInBlock(IRInst* inst) +{ + SLANG_RELEASE_ASSERT(inst); + + if (auto block = as<IRBlock>(inst->getParent())) + return inst; + + return getInstInBlock(inst->getParent()); +} + +UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst* arg) +{ + SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>(block->getTerminator())); + + auto branchInst = as<IRUnconditionalBranch>(block->getTerminator()); + List<IRInst*> phiArgs; + + for (UIndex ii = 0; ii < branchInst->getArgCount(); ii++) + phiArgs.add(branchInst->getArg(ii)); + + phiArgs.add(arg); + + builder->setInsertInto(block); + switch (branchInst->getOp()) + { + case kIROp_unconditionalBranch: + builder->emitBranch(branchInst->getTargetBlock(), phiArgs.getCount(), phiArgs.getBuffer()); + break; + + case kIROp_loop: + builder->emitLoop( + as<IRLoop>(branchInst)->getTargetBlock(), + as<IRLoop>(branchInst)->getBreakBlock(), + as<IRLoop>(branchInst)->getContinueBlock(), + phiArgs.getCount(), + phiArgs.getBuffer()); + break; + + default: + SLANG_UNEXPECTED("Unexpected branch-type for phi replacement"); + } + + branchInst->removeAndDeallocate(); + return phiArgs.getCount() - 1; +} + +IRUse* findUniqueStoredVal(IRVar* var) +{ + if (isDerivativeContextVar(var)) + { + IRUse* primalCallUse = nullptr; + for (auto use = var->firstUse; use; use = use->nextUse) + { + if (auto callInst = as<IRCall>(use->getUser())) + { + // Should not see more than one IRCall. If we do + // we'll need to pick the primal call. + // + SLANG_RELEASE_ASSERT(!primalCallUse); + primalCallUse = use; + } + } + return primalCallUse; + } + else + { + IRUse* storeUse = nullptr; + for (auto use = var->firstUse; use; use = use->nextUse) + { + if (auto storeInst = as<IRStore>(use->getUser())) + { + // Should not see more than one IRStore + SLANG_RELEASE_ASSERT(!storeUse); + storeUse = use; + } + } + return storeUse; + } +} + +bool isDerivativeContextVar(IRVar* var) +{ + return var->findDecoration<IRBackwardDerivativePrimalContextDecoration>(); +} + } |
