summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-03-15 22:26:58 -0400
committerGitHub <noreply@github.com>2023-03-15 19:26:58 -0700
commit71efd949fa5276e2464416fcf237f8fd2c486281 (patch)
treea5b24cd077f2ecc3f74d4dd4671c8260eb6e9b67 /source/slang/slang-ir-autodiff.cpp
parent38e62199cc75ce34608491c8dd299eb330bde518 (diff)
AD: Primal-Hoisting Rework + Checkpoint Policy Framework (#2702)
Diffstat (limited to 'source/slang/slang-ir-autodiff.cpp')
-rw-r--r--source/slang/slang-ir-autodiff.cpp98
1 files changed, 97 insertions, 1 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 517b9e3ea..a3a7e4b77 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -595,6 +595,7 @@ bool canTypeBeStored(IRInst* type)
case kIROp_FloatType:
case kIROp_VectorType:
case kIROp_MatrixType:
+ case kIROp_AttributedType:
return true;
default:
return false;
@@ -904,7 +905,7 @@ struct AutoDiffPass : public InstPassBase
else
{
IntermediateContextTypeDifferentialInfo diffFieldTypeInfo;
- diffTypes.TryGetValue(field->getDataType(), diffFieldTypeInfo);
+ diffTypes.TryGetValue(field->getFieldType(), diffFieldTypeInfo);
diffFieldWitness = diffFieldTypeInfo.diffWitness;
}
if (diffFieldWitness)
@@ -1429,4 +1430,99 @@ bool finalizeAutoDiffPass(IRModule* module)
return false;
}
+IRBlock* getBlock(IRInst* inst)
+{
+ SLANG_RELEASE_ASSERT(inst);
+
+ if (auto block = as<IRBlock>(inst))
+ return block;
+
+ return getBlock(inst->getParent());
+}
+
+IRInst* getInstInBlock(IRInst* inst)
+{
+ SLANG_RELEASE_ASSERT(inst);
+
+ if (auto block = as<IRBlock>(inst->getParent()))
+ return inst;
+
+ return getInstInBlock(inst->getParent());
+}
+
+UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst* arg)
+{
+ SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>(block->getTerminator()));
+
+ auto branchInst = as<IRUnconditionalBranch>(block->getTerminator());
+ List<IRInst*> phiArgs;
+
+ for (UIndex ii = 0; ii < branchInst->getArgCount(); ii++)
+ phiArgs.add(branchInst->getArg(ii));
+
+ phiArgs.add(arg);
+
+ builder->setInsertInto(block);
+ switch (branchInst->getOp())
+ {
+ case kIROp_unconditionalBranch:
+ builder->emitBranch(branchInst->getTargetBlock(), phiArgs.getCount(), phiArgs.getBuffer());
+ break;
+
+ case kIROp_loop:
+ builder->emitLoop(
+ as<IRLoop>(branchInst)->getTargetBlock(),
+ as<IRLoop>(branchInst)->getBreakBlock(),
+ as<IRLoop>(branchInst)->getContinueBlock(),
+ phiArgs.getCount(),
+ phiArgs.getBuffer());
+ break;
+
+ default:
+ SLANG_UNEXPECTED("Unexpected branch-type for phi replacement");
+ }
+
+ branchInst->removeAndDeallocate();
+ return phiArgs.getCount() - 1;
+}
+
+IRUse* findUniqueStoredVal(IRVar* var)
+{
+ if (isDerivativeContextVar(var))
+ {
+ IRUse* primalCallUse = nullptr;
+ for (auto use = var->firstUse; use; use = use->nextUse)
+ {
+ if (auto callInst = as<IRCall>(use->getUser()))
+ {
+ // Should not see more than one IRCall. If we do
+ // we'll need to pick the primal call.
+ //
+ SLANG_RELEASE_ASSERT(!primalCallUse);
+ primalCallUse = use;
+ }
+ }
+ return primalCallUse;
+ }
+ else
+ {
+ IRUse* storeUse = nullptr;
+ for (auto use = var->firstUse; use; use = use->nextUse)
+ {
+ if (auto storeInst = as<IRStore>(use->getUser()))
+ {
+ // Should not see more than one IRStore
+ SLANG_RELEASE_ASSERT(!storeUse);
+ storeUse = use;
+ }
+ }
+ return storeUse;
+ }
+}
+
+bool isDerivativeContextVar(IRVar* var)
+{
+ return var->findDecoration<IRBackwardDerivativePrimalContextDecoration>();
+}
+
}