diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-08 21:52:34 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-08 21:52:34 -0800 |
| commit | 86fc50c5092fbccf6072dcf7bbdfafb8915f02c8 (patch) | |
| tree | b4f9eb6cb1eea88145fde0bd1f670a8803120257 /source/slang/slang-ir-autodiff-unzip.cpp | |
| parent | 257733f328f38a763c8b0c8830ff4c0d34ec9491 (diff) | |
Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. (#2691)
* Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`.
* Fix
* Fix.
* Cleanup.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-unzip.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.cpp | 26 |
1 files changed, 22 insertions, 4 deletions
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index a05fe7044..2347c7a8f 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -197,8 +197,18 @@ struct ExtractPrimalFuncContext bool shouldStoreVar(IRVar* var) { // Always store intermediate context var. - if (var->findDecoration<IRBackwardDerivativePrimalContextDecoration>()) + if (auto typeDecor = var->findDecoration<IRBackwardDerivativePrimalContextDecoration>()) { + // If we are specializing a callee's intermediate context with types that can't be stored, + // we can't store the entire context. + if (auto spec = as<IRSpecialize>(as<IRPtrTypeBase>(var->getDataType())->getValueType())) + { + for (UInt i = 0; i < spec->getArgCount(); i++) + { + if (!canTypeBeStored(spec->getArg(i)->getDataType())) + return false; + } + } return true; } @@ -212,7 +222,7 @@ struct ExtractPrimalFuncContext // 2. Does the var have a store // - return (doesInstHaveDiffUse(var) && doesInstHaveStore(var)); + return (doesInstHaveDiffUse(var) && doesInstHaveStore(var) && canTypeBeStored(as<IRPtrTypeBase>(var->getDataType())->getValueType())); } bool shouldStoreInst(IRInst* inst) @@ -222,7 +232,7 @@ struct ExtractPrimalFuncContext return false; } - if (!canInstBeStored(inst)) + if (!canTypeBeStored(inst->getDataType())) return false; // Never store certain opcodes. @@ -246,6 +256,9 @@ struct ExtractPrimalFuncContext case kIROp_MakeOptionalValue: case kIROp_DifferentialPairGetDifferential: case kIROp_DifferentialPairGetPrimal: + case kIROp_ExtractExistentialValue: + case kIROp_ExtractExistentialType: + case kIROp_ExtractExistentialWitnessTable: return false; case kIROp_GetElement: case kIROp_FieldExtract: @@ -560,7 +573,12 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( if (inst->getOp() == kIROp_Call) { // The primal calls should be marked as no side effect so they can be DCE'd if possible. - builder.addSimpleDecoration<IRNoSideEffectDecoration>(inst); + // We can only do so if the intermediate context of the callee is stored. + if (primalCtx->getBackwardDerivativePrimalContextVar() + ->findDecoration<IRPrimalValueStructKeyDecoration>()) + { + builder.addSimpleDecoration<IRNoSideEffectDecoration>(inst); + } } } } |
