diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-02-17 18:02:58 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-17 15:02:58 -0800 |
| commit | 92ccc8f17881d010f399b63aee80ba20bdc7095c (patch) | |
| tree | dbbb6684ac90e36786c3973393515f8d16f0d5f6 | |
| parent | 5cd39d1527f87ebab966cbd9c136b93058a709bc (diff) | |
AD: More legacy type handling cleanup + user-defined reverse-mode fix (#2662)
* WIP: Remove all legacy type checking
* Fixed issue with user-defined backward derivatives not bypassing the AD process
---------
Co-authored-by: Yong He <yonghe@outlook.com>
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.h | 5 | ||||
| -rw-r--r-- | tests/autodiff/backward-diff-check.slang | 1 |
4 files changed, 7 insertions, 10 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index fa1ca8519..ef6178976 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -295,9 +295,6 @@ namespace Slang // Create an empty func to represent the transcribed func of `origFunc`. InstPair BackwardDiffTranscriberBase::transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc) { - if (auto bwdDiffFunc = findExistingDiffFunc(origFunc)) - return InstPair(origFunc, bwdDiffFunc); - if (!isBackwardDifferentiableFunc(origFunc)) return InstPair(nullptr, nullptr); @@ -379,11 +376,15 @@ namespace Slang InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) { + if (auto bwdDiffFunc = findExistingDiffFunc(origFunc)) + return InstPair(origFunc, bwdDiffFunc); + auto header = transcribeFuncHeaderImpl(inBuilder, origFunc); if (!header.differential) return header; - + IRBuilder builder = *inBuilder; + builder.setInsertInto(header.differential); builder.emitBlock(); auto origFuncType = as<IRFuncType>(origFunc->getFullType()); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 7ab0ee692..a8e06bf91 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -879,7 +879,7 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst) { - // If a differential intstruction is already mapped for + // If a differential instruction is already mapped for // this original inst, return that. // if (auto diffInst = lookupDiffInst(origInst, nullptr)) diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index 7bcd4c90b..f57fb2974 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -170,11 +170,6 @@ struct DifferentiableTypeConformanceContext { switch (origType->getOp()) { - case kIROp_FloatType: - case kIROp_HalfType: - case kIROp_DoubleType: - return origType; - case kIROp_ArrayType: { auto diffElementType = (IRType*)getDifferentialForType( diff --git a/tests/autodiff/backward-diff-check.slang b/tests/autodiff/backward-diff-check.slang index 2718f31f1..0e44cfece 100644 --- a/tests/autodiff/backward-diff-check.slang +++ b/tests/autodiff/backward-diff-check.slang @@ -14,6 +14,7 @@ float test() } float noDiffFunc(float x) { return x; } + [BackwardDerivativeOf(test)] void d_test(float dOut) { |
