summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff-rev.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-12-19 11:47:19 -0800
committerGitHub <noreply@github.com>2022-12-19 11:47:19 -0800
commit216dfba0af66210a46ef0df18beb73d975fdf727 (patch)
treef397ea5bf8d47d7a5d90dc95edfb472f2e49d762 /source/slang/slang-ir-autodiff-rev.cpp
parent36220da1e29c891972fef32c8575c15f868b9959 (diff)
Separate primal computations from unzipped function into an explicit function. (#2569)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-rev.cpp')
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp105
1 files changed, 101 insertions, 4 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index c7fbc415a..56002231a 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -491,6 +491,98 @@ struct BackwardDiffTranscriber
builder.emitBranch(firstBlock);
}
+ void cleanUpUnusedPrimalIntermediate(IRInst* func, IRInst* primalFunc, IRInst* intermediateType)
+ {
+ IRStructType* structType = as<IRStructType>(intermediateType);
+ if (!structType)
+ {
+ auto genType = as<IRGeneric>(intermediateType);
+ structType = as<IRStructType>(findGenericReturnVal(genType));
+ SLANG_RELEASE_ASSERT(structType);
+ }
+
+ // Collect fields that are never fetched by reverse func.
+ OrderedHashSet<IRStructKey*> fieldsToCleanup;
+ for (auto children : structType->getChildren())
+ {
+ if (auto field = as<IRStructField>(children))
+ {
+ auto structKey = field->getKey();
+ bool usedByRevFunc = false;
+ for (auto use = structKey->firstUse; use; use = use->nextUse)
+ {
+ if (isChildInstOf(use->getUser(), func))
+ {
+ usedByRevFunc = true;
+ break;
+ }
+ }
+ if (!usedByRevFunc)
+ {
+ List<IRInst*> users;
+ for (auto use = structKey->firstUse; use; use = use->nextUse)
+ {
+ users.add(use->getUser());
+ }
+ for (auto user : users)
+ {
+ if (!isChildInstOf(user, primalFunc))
+ continue;
+ if (auto addr = as<IRFieldAddress>(user))
+ {
+ if (addr->hasMoreThanOneUse())
+ continue;
+ if (addr->firstUse)
+ {
+ if (addr->firstUse->getUser()->getOp() == kIROp_Store)
+ {
+ addr->firstUse->getUser()->removeAndDeallocate();
+ }
+ addr->removeAndDeallocate();
+ }
+ }
+ }
+
+ bool hasNonTrivialUse = false;
+ for (auto use = structKey->firstUse; use; use = use->nextUse)
+ {
+ switch (use->getUser()->getOp())
+ {
+ case kIROp_PrimalValueStructKeyDecoration:
+ case kIROp_StructField:
+ continue;
+ default:
+ hasNonTrivialUse = true;
+ break;
+ }
+ }
+ if (!hasNonTrivialUse)
+ {
+ fieldsToCleanup.Add(structKey);
+ }
+ }
+ }
+ }
+
+ // Actually remove fields from struct.
+ for (auto children : structType->getChildren())
+ {
+ if (auto field = as<IRStructField>(children))
+ {
+ if (fieldsToCleanup.Contains(field->getKey()))
+ {
+ auto key = field->getKey();
+ List<IRInst*> keyUsers;
+ for (auto use = key->firstUse; use; use = use->nextUse)
+ keyUsers.add(use->getUser());
+ for (auto keyUser : keyUsers)
+ keyUser->removeAndDeallocate();
+ key->removeAndDeallocate();
+ }
+ }
+ }
+ }
+
// Transcribe a function definition.
InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc)
{
@@ -520,12 +612,9 @@ struct BackwardDiffTranscriber
// second block of the unzipped function.
//
IRFunc* unzippedFwdDiffFunc = diffUnzipPass->unzipDiffInsts(fwdDiffFunc);
-
+
// Clone the primal blocks from unzippedFwdDiffFunc
// to the reverse-mode function.
- // TODO: This is the spot where we can make a decision to split
- // the primal and differential into two different funcitons
- // instead of two blocks in the same function.
//
// Special care needs to be taken for the first block since it holds the parameters
@@ -547,6 +636,11 @@ struct BackwardDiffTranscriber
block->insertAtEnd(diffFunc);
}
+ // Extracts the primal computations into its own func, and replace the primal insts
+ // with the intermediate results computed from the extracted func.
+ IRInst* intermediateType = nullptr;
+ auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(diffFunc, unzippedFwdDiffFunc, intermediateType);
+
// Transpose the first block (parameter block)
transcribeParameterBlock(builder, diffFunc);
@@ -563,6 +657,9 @@ struct BackwardDiffTranscriber
unzippedFwdDiffFunc->removeAndDeallocate();
fwdDiffFunc->removeAndDeallocate();
+ eliminateDeadCode(diffFunc);
+ cleanUpUnusedPrimalIntermediate(diffFunc, extractedPrimalFunc, intermediateType);
+
return InstPair(primalFunc, diffFunc);
}