From 86fc50c5092fbccf6072dcf7bbdfafb8915f02c8 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 8 Mar 2023 21:52:34 -0800 Subject: Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. (#2691) * Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. * Fix * Fix. * Cleanup. --------- Co-authored-by: Yong He --- source/slang/slang-ir-autodiff-unzip.cpp | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) (limited to 'source/slang/slang-ir-autodiff-unzip.cpp') diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index a05fe7044..2347c7a8f 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -197,8 +197,18 @@ struct ExtractPrimalFuncContext bool shouldStoreVar(IRVar* var) { // Always store intermediate context var. - if (var->findDecoration()) + if (auto typeDecor = var->findDecoration()) { + // If we are specializing a callee's intermediate context with types that can't be stored, + // we can't store the entire context. + if (auto spec = as(as(var->getDataType())->getValueType())) + { + for (UInt i = 0; i < spec->getArgCount(); i++) + { + if (!canTypeBeStored(spec->getArg(i)->getDataType())) + return false; + } + } return true; } @@ -212,7 +222,7 @@ struct ExtractPrimalFuncContext // 2. Does the var have a store // - return (doesInstHaveDiffUse(var) && doesInstHaveStore(var)); + return (doesInstHaveDiffUse(var) && doesInstHaveStore(var) && canTypeBeStored(as(var->getDataType())->getValueType())); } bool shouldStoreInst(IRInst* inst) @@ -222,7 +232,7 @@ struct ExtractPrimalFuncContext return false; } - if (!canInstBeStored(inst)) + if (!canTypeBeStored(inst->getDataType())) return false; // Never store certain opcodes. @@ -246,6 +256,9 @@ struct ExtractPrimalFuncContext case kIROp_MakeOptionalValue: case kIROp_DifferentialPairGetDifferential: case kIROp_DifferentialPairGetPrimal: + case kIROp_ExtractExistentialValue: + case kIROp_ExtractExistentialType: + case kIROp_ExtractExistentialWitnessTable: return false; case kIROp_GetElement: case kIROp_FieldExtract: @@ -560,7 +573,12 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( if (inst->getOp() == kIROp_Call) { // The primal calls should be marked as no side effect so they can be DCE'd if possible. - builder.addSimpleDecoration(inst); + // We can only do so if the intermediate context of the callee is stored. + if (primalCtx->getBackwardDerivativePrimalContextVar() + ->findDecoration()) + { + builder.addSimpleDecoration(inst); + } } } } -- cgit v1.2.3