From 71efd949fa5276e2464416fcf237f8fd2c486281 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 15 Mar 2023 22:26:58 -0400 Subject: AD: Primal-Hoisting Rework + Checkpoint Policy Framework (#2702) --- source/slang/slang-ir-autodiff.cpp | 98 +++++++++++++++++++++++++++++++++++++- 1 file changed, 97 insertions(+), 1 deletion(-) (limited to 'source/slang/slang-ir-autodiff.cpp') diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 517b9e3ea..a3a7e4b77 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -595,6 +595,7 @@ bool canTypeBeStored(IRInst* type) case kIROp_FloatType: case kIROp_VectorType: case kIROp_MatrixType: + case kIROp_AttributedType: return true; default: return false; @@ -904,7 +905,7 @@ struct AutoDiffPass : public InstPassBase else { IntermediateContextTypeDifferentialInfo diffFieldTypeInfo; - diffTypes.TryGetValue(field->getDataType(), diffFieldTypeInfo); + diffTypes.TryGetValue(field->getFieldType(), diffFieldTypeInfo); diffFieldWitness = diffFieldTypeInfo.diffWitness; } if (diffFieldWitness) @@ -1429,4 +1430,99 @@ bool finalizeAutoDiffPass(IRModule* module) return false; } +IRBlock* getBlock(IRInst* inst) +{ + SLANG_RELEASE_ASSERT(inst); + + if (auto block = as(inst)) + return block; + + return getBlock(inst->getParent()); +} + +IRInst* getInstInBlock(IRInst* inst) +{ + SLANG_RELEASE_ASSERT(inst); + + if (auto block = as(inst->getParent())) + return inst; + + return getInstInBlock(inst->getParent()); +} + +UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst* arg) +{ + SLANG_RELEASE_ASSERT(as(block->getTerminator())); + + auto branchInst = as(block->getTerminator()); + List phiArgs; + + for (UIndex ii = 0; ii < branchInst->getArgCount(); ii++) + phiArgs.add(branchInst->getArg(ii)); + + phiArgs.add(arg); + + builder->setInsertInto(block); + switch (branchInst->getOp()) + { + case kIROp_unconditionalBranch: + builder->emitBranch(branchInst->getTargetBlock(), phiArgs.getCount(), phiArgs.getBuffer()); + break; + + case kIROp_loop: + builder->emitLoop( + as(branchInst)->getTargetBlock(), + as(branchInst)->getBreakBlock(), + as(branchInst)->getContinueBlock(), + phiArgs.getCount(), + phiArgs.getBuffer()); + break; + + default: + SLANG_UNEXPECTED("Unexpected branch-type for phi replacement"); + } + + branchInst->removeAndDeallocate(); + return phiArgs.getCount() - 1; +} + +IRUse* findUniqueStoredVal(IRVar* var) +{ + if (isDerivativeContextVar(var)) + { + IRUse* primalCallUse = nullptr; + for (auto use = var->firstUse; use; use = use->nextUse) + { + if (auto callInst = as(use->getUser())) + { + // Should not see more than one IRCall. If we do + // we'll need to pick the primal call. + // + SLANG_RELEASE_ASSERT(!primalCallUse); + primalCallUse = use; + } + } + return primalCallUse; + } + else + { + IRUse* storeUse = nullptr; + for (auto use = var->firstUse; use; use = use->nextUse) + { + if (auto storeInst = as(use->getUser())) + { + // Should not see more than one IRStore + SLANG_RELEASE_ASSERT(!storeUse); + storeUse = use; + } + } + return storeUse; + } +} + +bool isDerivativeContextVar(IRVar* var) +{ + return var->findDecoration(); +} + } -- cgit v1.2.3