From 366c9b4526b4b940c8aafce459d6784211e862bc Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 6 Aug 2024 18:24:04 -0400 Subject: Fix auto-diff synthesized method naming conventions (#4714) * Fix auto-diff synthesized method naming conventions * Update tests; remove unused var --- .../slang/slang-ir-autodiff-transcriber-base.cpp | 23 +++++++++++++++------- source/slang/slang-ir-autodiff-unzip.cpp | 2 +- 2 files changed, 17 insertions(+), 8 deletions(-) (limited to 'source') diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index f4a34c5aa..a1fa5f21a 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -1079,6 +1079,20 @@ IRInst* getActualInstToTranscribe(IRInst* inst) return inst; } +void handleNameHint(IRBuilder* builder, IRInst* primal, IRInst* diff) +{ + // Ignore types that already have a name hint. + if (as(diff) && diff->findDecoration()) + return; + + if (auto nameHint = primal->findDecoration()) + { + StringBuilder sb; + sb << "s_diff_" << nameHint->getName(); + builder->addNameHintDecoration(diff, sb.getUnownedSlice()); + } +} + IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst) { // If a differential instruction is already mapped for @@ -1099,7 +1113,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst instsInProgress.remove(origInst); - if (auto primalInst = pair.primal) + if (pair.primal) { mapPrimalInst(origInst, pair.primal); mapDifferentialInst(origInst, pair.differential); @@ -1124,12 +1138,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst break; default: // Generate name hint for the inst. - if (auto primalNameHint = primalInst->findDecoration()) - { - StringBuilder sb; - sb << "s_diff_" << primalNameHint->getName(); - builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice()); - } + handleNameHint(builder, pair.primal, pair.differential); // Automatically tag the primal and differential results // if they haven't already been handled by the diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 59653c4ae..9b3e3a324 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -346,7 +346,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( } if (auto originalNameHint = originalFunc->findDecoration()) { - auto primalName = String("s_bwd_primal_") + UnownedStringSlice(originalNameHint->getName()); + auto primalName = String("s_primal_ctx_") + UnownedStringSlice(originalNameHint->getName()); builder.addNameHintDecoration(primalFunc, builder.getStringValue(primalName.getUnownedSlice())); } -- cgit v1.2.3