summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-02-17 18:02:58 -0500
committerGitHub <noreply@github.com>2023-02-17 15:02:58 -0800
commit92ccc8f17881d010f399b63aee80ba20bdc7095c (patch)
treedbbb6684ac90e36786c3973393515f8d16f0d5f6
parent5cd39d1527f87ebab966cbd9c136b93058a709bc (diff)
AD: More legacy type handling cleanup + user-defined reverse-mode fix (#2662)
* WIP: Remove all legacy type checking * Fixed issue with user-defined backward derivatives not bypassing the AD process --------- Co-authored-by: Yong He <yonghe@outlook.com>
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp9
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp2
-rw-r--r--source/slang/slang-ir-autodiff.h5
-rw-r--r--tests/autodiff/backward-diff-check.slang1
4 files changed, 7 insertions, 10 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index fa1ca8519..ef6178976 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -295,9 +295,6 @@ namespace Slang
// Create an empty func to represent the transcribed func of `origFunc`.
InstPair BackwardDiffTranscriberBase::transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc)
{
- if (auto bwdDiffFunc = findExistingDiffFunc(origFunc))
- return InstPair(origFunc, bwdDiffFunc);
-
if (!isBackwardDifferentiableFunc(origFunc))
return InstPair(nullptr, nullptr);
@@ -379,11 +376,15 @@ namespace Slang
InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc)
{
+ if (auto bwdDiffFunc = findExistingDiffFunc(origFunc))
+ return InstPair(origFunc, bwdDiffFunc);
+
auto header = transcribeFuncHeaderImpl(inBuilder, origFunc);
if (!header.differential)
return header;
-
+
IRBuilder builder = *inBuilder;
+
builder.setInsertInto(header.differential);
builder.emitBlock();
auto origFuncType = as<IRFuncType>(origFunc->getFullType());
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 7ab0ee692..a8e06bf91 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -879,7 +879,7 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene
IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst)
{
- // If a differential intstruction is already mapped for
+ // If a differential instruction is already mapped for
// this original inst, return that.
//
if (auto diffInst = lookupDiffInst(origInst, nullptr))
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index 7bcd4c90b..f57fb2974 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -170,11 +170,6 @@ struct DifferentiableTypeConformanceContext
{
switch (origType->getOp())
{
- case kIROp_FloatType:
- case kIROp_HalfType:
- case kIROp_DoubleType:
- return origType;
-
case kIROp_ArrayType:
{
auto diffElementType = (IRType*)getDifferentialForType(
diff --git a/tests/autodiff/backward-diff-check.slang b/tests/autodiff/backward-diff-check.slang
index 2718f31f1..0e44cfece 100644
--- a/tests/autodiff/backward-diff-check.slang
+++ b/tests/autodiff/backward-diff-check.slang
@@ -14,6 +14,7 @@ float test()
}
float noDiffFunc(float x) { return x; }
+
[BackwardDerivativeOf(test)]
void d_test(float dOut)
{