summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-08-06 18:24:04 -0400
committerGitHub <noreply@github.com>2024-08-06 18:24:04 -0400
commit366c9b4526b4b940c8aafce459d6784211e862bc (patch)
tree9f717b50095895c9471f8e4cc3a49850c0455892 /source/slang
parent33e9de0ef2aa04e7681769104dbe524f8b68525a (diff)
Fix auto-diff synthesized method naming conventions (#4714)
* Fix auto-diff synthesized method naming conventions * Update tests; remove unused var
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp23
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp2
2 files changed, 17 insertions, 8 deletions
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index f4a34c5aa..a1fa5f21a 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -1079,6 +1079,20 @@ IRInst* getActualInstToTranscribe(IRInst* inst)
return inst;
}
+void handleNameHint(IRBuilder* builder, IRInst* primal, IRInst* diff)
+{
+ // Ignore types that already have a name hint.
+ if (as<IRType>(diff) && diff->findDecoration<IRNameHintDecoration>())
+ return;
+
+ if (auto nameHint = primal->findDecoration<IRNameHintDecoration>())
+ {
+ StringBuilder sb;
+ sb << "s_diff_" << nameHint->getName();
+ builder->addNameHintDecoration(diff, sb.getUnownedSlice());
+ }
+}
+
IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst)
{
// If a differential instruction is already mapped for
@@ -1099,7 +1113,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst
instsInProgress.remove(origInst);
- if (auto primalInst = pair.primal)
+ if (pair.primal)
{
mapPrimalInst(origInst, pair.primal);
mapDifferentialInst(origInst, pair.differential);
@@ -1124,12 +1138,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst
break;
default:
// Generate name hint for the inst.
- if (auto primalNameHint = primalInst->findDecoration<IRNameHintDecoration>())
- {
- StringBuilder sb;
- sb << "s_diff_" << primalNameHint->getName();
- builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice());
- }
+ handleNameHint(builder, pair.primal, pair.differential);
// Automatically tag the primal and differential results
// if they haven't already been handled by the
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 59653c4ae..9b3e3a324 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -346,7 +346,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
}
if (auto originalNameHint = originalFunc->findDecoration<IRNameHintDecoration>())
{
- auto primalName = String("s_bwd_primal_") + UnownedStringSlice(originalNameHint->getName());
+ auto primalName = String("s_primal_ctx_") + UnownedStringSlice(originalNameHint->getName());
builder.addNameHintDecoration(primalFunc, builder.getStringValue(primalName.getUnownedSlice()));
}