summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff-unzip.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-08 21:52:34 -0800
committerGitHub <noreply@github.com>2023-03-08 21:52:34 -0800
commit86fc50c5092fbccf6072dcf7bbdfafb8915f02c8 (patch)
treeb4f9eb6cb1eea88145fde0bd1f670a8803120257 /source/slang/slang-ir-autodiff-unzip.cpp
parent257733f328f38a763c8b0c8830ff4c0d34ec9491 (diff)
Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. (#2691)
* Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. * Fix * Fix. * Cleanup. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-unzip.cpp')
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp26
1 files changed, 22 insertions, 4 deletions
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index a05fe7044..2347c7a8f 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -197,8 +197,18 @@ struct ExtractPrimalFuncContext
bool shouldStoreVar(IRVar* var)
{
// Always store intermediate context var.
- if (var->findDecoration<IRBackwardDerivativePrimalContextDecoration>())
+ if (auto typeDecor = var->findDecoration<IRBackwardDerivativePrimalContextDecoration>())
{
+ // If we are specializing a callee's intermediate context with types that can't be stored,
+ // we can't store the entire context.
+ if (auto spec = as<IRSpecialize>(as<IRPtrTypeBase>(var->getDataType())->getValueType()))
+ {
+ for (UInt i = 0; i < spec->getArgCount(); i++)
+ {
+ if (!canTypeBeStored(spec->getArg(i)->getDataType()))
+ return false;
+ }
+ }
return true;
}
@@ -212,7 +222,7 @@ struct ExtractPrimalFuncContext
// 2. Does the var have a store
//
- return (doesInstHaveDiffUse(var) && doesInstHaveStore(var));
+ return (doesInstHaveDiffUse(var) && doesInstHaveStore(var) && canTypeBeStored(as<IRPtrTypeBase>(var->getDataType())->getValueType()));
}
bool shouldStoreInst(IRInst* inst)
@@ -222,7 +232,7 @@ struct ExtractPrimalFuncContext
return false;
}
- if (!canInstBeStored(inst))
+ if (!canTypeBeStored(inst->getDataType()))
return false;
// Never store certain opcodes.
@@ -246,6 +256,9 @@ struct ExtractPrimalFuncContext
case kIROp_MakeOptionalValue:
case kIROp_DifferentialPairGetDifferential:
case kIROp_DifferentialPairGetPrimal:
+ case kIROp_ExtractExistentialValue:
+ case kIROp_ExtractExistentialType:
+ case kIROp_ExtractExistentialWitnessTable:
return false;
case kIROp_GetElement:
case kIROp_FieldExtract:
@@ -560,7 +573,12 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
if (inst->getOp() == kIROp_Call)
{
// The primal calls should be marked as no side effect so they can be DCE'd if possible.
- builder.addSimpleDecoration<IRNoSideEffectDecoration>(inst);
+ // We can only do so if the intermediate context of the callee is stored.
+ if (primalCtx->getBackwardDerivativePrimalContextVar()
+ ->findDecoration<IRPrimalValueStructKeyDecoration>())
+ {
+ builder.addSimpleDecoration<IRNoSideEffectDecoration>(inst);
+ }
}
}
}