diff options
| author | Yong He <yonghe@outlook.com> | 2022-12-19 11:47:19 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-12-19 11:47:19 -0800 |
| commit | 216dfba0af66210a46ef0df18beb73d975fdf727 (patch) | |
| tree | f397ea5bf8d47d7a5d90dc95edfb472f2e49d762 /source/slang/slang-ir-autodiff-rev.cpp | |
| parent | 36220da1e29c891972fef32c8575c15f868b9959 (diff) | |
Separate primal computations from unzipped function into an explicit function. (#2569)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-rev.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 105 |
1 files changed, 101 insertions, 4 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index c7fbc415a..56002231a 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -491,6 +491,98 @@ struct BackwardDiffTranscriber builder.emitBranch(firstBlock); } + void cleanUpUnusedPrimalIntermediate(IRInst* func, IRInst* primalFunc, IRInst* intermediateType) + { + IRStructType* structType = as<IRStructType>(intermediateType); + if (!structType) + { + auto genType = as<IRGeneric>(intermediateType); + structType = as<IRStructType>(findGenericReturnVal(genType)); + SLANG_RELEASE_ASSERT(structType); + } + + // Collect fields that are never fetched by reverse func. + OrderedHashSet<IRStructKey*> fieldsToCleanup; + for (auto children : structType->getChildren()) + { + if (auto field = as<IRStructField>(children)) + { + auto structKey = field->getKey(); + bool usedByRevFunc = false; + for (auto use = structKey->firstUse; use; use = use->nextUse) + { + if (isChildInstOf(use->getUser(), func)) + { + usedByRevFunc = true; + break; + } + } + if (!usedByRevFunc) + { + List<IRInst*> users; + for (auto use = structKey->firstUse; use; use = use->nextUse) + { + users.add(use->getUser()); + } + for (auto user : users) + { + if (!isChildInstOf(user, primalFunc)) + continue; + if (auto addr = as<IRFieldAddress>(user)) + { + if (addr->hasMoreThanOneUse()) + continue; + if (addr->firstUse) + { + if (addr->firstUse->getUser()->getOp() == kIROp_Store) + { + addr->firstUse->getUser()->removeAndDeallocate(); + } + addr->removeAndDeallocate(); + } + } + } + + bool hasNonTrivialUse = false; + for (auto use = structKey->firstUse; use; use = use->nextUse) + { + switch (use->getUser()->getOp()) + { + case kIROp_PrimalValueStructKeyDecoration: + case kIROp_StructField: + continue; + default: + hasNonTrivialUse = true; + break; + } + } + if (!hasNonTrivialUse) + { + fieldsToCleanup.Add(structKey); + } + } + } + } + + // Actually remove fields from struct. + for (auto children : structType->getChildren()) + { + if (auto field = as<IRStructField>(children)) + { + if (fieldsToCleanup.Contains(field->getKey())) + { + auto key = field->getKey(); + List<IRInst*> keyUsers; + for (auto use = key->firstUse; use; use = use->nextUse) + keyUsers.add(use->getUser()); + for (auto keyUser : keyUsers) + keyUser->removeAndDeallocate(); + key->removeAndDeallocate(); + } + } + } + } + // Transcribe a function definition. InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) { @@ -520,12 +612,9 @@ struct BackwardDiffTranscriber // second block of the unzipped function. // IRFunc* unzippedFwdDiffFunc = diffUnzipPass->unzipDiffInsts(fwdDiffFunc); - + // Clone the primal blocks from unzippedFwdDiffFunc // to the reverse-mode function. - // TODO: This is the spot where we can make a decision to split - // the primal and differential into two different funcitons - // instead of two blocks in the same function. // // Special care needs to be taken for the first block since it holds the parameters @@ -547,6 +636,11 @@ struct BackwardDiffTranscriber block->insertAtEnd(diffFunc); } + // Extracts the primal computations into its own func, and replace the primal insts + // with the intermediate results computed from the extracted func. + IRInst* intermediateType = nullptr; + auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(diffFunc, unzippedFwdDiffFunc, intermediateType); + // Transpose the first block (parameter block) transcribeParameterBlock(builder, diffFunc); @@ -563,6 +657,9 @@ struct BackwardDiffTranscriber unzippedFwdDiffFunc->removeAndDeallocate(); fwdDiffFunc->removeAndDeallocate(); + eliminateDeadCode(diffFunc); + cleanUpUnusedPrimalIntermediate(diffFunc, extractedPrimalFunc, intermediateType); + return InstPair(primalFunc, diffFunc); } |
