diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 13 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.h | 2 |
4 files changed, 21 insertions, 0 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index e78217fe3..444816ff7 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1508,6 +1508,9 @@ IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, I builder.addForwardDifferentiableDecoration(diffFunc); if (isBackwardDifferentiableFunc(origFunc)) builder.addBackwardDifferentiableDecoration(diffFunc); + + // Transfer checkpoint hint decorations + copyCheckpointHints(&builder, origFunc, diffFunc); // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc. if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>()) diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index ecc36d6ba..70c43cdcb 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -340,6 +340,9 @@ namespace Slang builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice()); } + // Transfer checkpoint hint decorations + copyCheckpointHints(&builder, origFunc, diffFunc); + // Mark the generated derivative function itself as differentiable. builder.addBackwardDifferentiableDecoration(diffFunc); // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc. diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 3b3224e2f..f6a977994 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -749,6 +749,19 @@ IRInst* DifferentiableTypeConformanceContext::getExtractExistensialTypeWitness(I SLANG_UNIMPLEMENTED_X("TODO: Implement"); } + +void copyCheckpointHints(IRBuilder* builder, IRGlobalValueWithCode* oldInst, IRGlobalValueWithCode* newInst) +{ + for (auto decor = oldInst->getFirstDecoration(); decor; decor = decor->getNextDecoration()) + { + if (auto chkHint = as<IRCheckpointHintDecoration>(decor)) + { + SLANG_ASSERT(chkHint->getOperandCount() == 0); + builder->addDecoration(newInst, chkHint->getOp()); + } + } +} + void stripDerivativeDecorations(IRInst* inst) { for (auto decor = inst->getFirstDecoration(); decor; ) diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index 91b45c5be..fdbf5c65e 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -316,6 +316,8 @@ bool finalizeAutoDiffPass(IRModule* module); // Utility methods +void copyCheckpointHints(IRBuilder* builder, IRGlobalValueWithCode* oldInst, IRGlobalValueWithCode* newInst); + void stripDerivativeDecorations(IRInst* inst); bool isBackwardDifferentiableFunc(IRInst* func); |
