summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp23
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp2
-rw-r--r--tests/autodiff/reverse-checkpoint-1.slang2
-rw-r--r--tests/autodiff/reverse-checkpoint-2.slang2
4 files changed, 19 insertions, 10 deletions
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index f4a34c5aa..a1fa5f21a 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -1079,6 +1079,20 @@ IRInst* getActualInstToTranscribe(IRInst* inst)
return inst;
}
+void handleNameHint(IRBuilder* builder, IRInst* primal, IRInst* diff)
+{
+ // Ignore types that already have a name hint.
+ if (as<IRType>(diff) && diff->findDecoration<IRNameHintDecoration>())
+ return;
+
+ if (auto nameHint = primal->findDecoration<IRNameHintDecoration>())
+ {
+ StringBuilder sb;
+ sb << "s_diff_" << nameHint->getName();
+ builder->addNameHintDecoration(diff, sb.getUnownedSlice());
+ }
+}
+
IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst)
{
// If a differential instruction is already mapped for
@@ -1099,7 +1113,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst
instsInProgress.remove(origInst);
- if (auto primalInst = pair.primal)
+ if (pair.primal)
{
mapPrimalInst(origInst, pair.primal);
mapDifferentialInst(origInst, pair.differential);
@@ -1124,12 +1138,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst
break;
default:
// Generate name hint for the inst.
- if (auto primalNameHint = primalInst->findDecoration<IRNameHintDecoration>())
- {
- StringBuilder sb;
- sb << "s_diff_" << primalNameHint->getName();
- builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice());
- }
+ handleNameHint(builder, pair.primal, pair.differential);
// Automatically tag the primal and differential results
// if they haven't already been handled by the
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 59653c4ae..9b3e3a324 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -346,7 +346,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
}
if (auto originalNameHint = originalFunc->findDecoration<IRNameHintDecoration>())
{
- auto primalName = String("s_bwd_primal_") + UnownedStringSlice(originalNameHint->getName());
+ auto primalName = String("s_primal_ctx_") + UnownedStringSlice(originalNameHint->getName());
builder.addNameHintDecoration(primalFunc, builder.getStringValue(primalName.getUnownedSlice()));
}
diff --git a/tests/autodiff/reverse-checkpoint-1.slang b/tests/autodiff/reverse-checkpoint-1.slang
index beb983b3b..517297013 100644
--- a/tests/autodiff/reverse-checkpoint-1.slang
+++ b/tests/autodiff/reverse-checkpoint-1.slang
@@ -29,7 +29,7 @@ float f(int p, float x)
// Check that there are no calls to primal_g in bwd_f.
// CHECK: void s_bwd_f_{{[0-9]+}}
-// CHECK-NOT: {{[_a-zA-Z0-9]+}} = s_bwd_primal_g_{{[0-9]+}}
+// CHECK-NOT: {{[_a-zA-Z0-9]+}} = s_primal_ctx_g_{{[0-9]+}}
// CHECK: return
diff --git a/tests/autodiff/reverse-checkpoint-2.slang b/tests/autodiff/reverse-checkpoint-2.slang
index 68ff62176..8a7262aa4 100644
--- a/tests/autodiff/reverse-checkpoint-2.slang
+++ b/tests/autodiff/reverse-checkpoint-2.slang
@@ -29,7 +29,7 @@ float f(int p, float x)
// Check that there are no calls to primal_g in bwd_f.
// CHECK: void s_bwd_prop_f_{{[0-9]+}}
-// CHECK: {{[_a-zA-Z0-9]+}} = s_bwd_primal_g_{{[0-9]+}}
+// CHECK: {{[_a-zA-Z0-9]+}} = s_primal_ctx_g_{{[0-9]+}}
// CHECK: return