diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-08-06 18:24:04 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-08-06 18:24:04 -0400 |
| commit | 366c9b4526b4b940c8aafce459d6784211e862bc (patch) | |
| tree | 9f717b50095895c9471f8e4cc3a49850c0455892 | |
| parent | 33e9de0ef2aa04e7681769104dbe524f8b68525a (diff) | |
Fix auto-diff synthesized method naming conventions (#4714)
* Fix auto-diff synthesized method naming conventions
* Update tests; remove unused var
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 23 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.cpp | 2 | ||||
| -rw-r--r-- | tests/autodiff/reverse-checkpoint-1.slang | 2 | ||||
| -rw-r--r-- | tests/autodiff/reverse-checkpoint-2.slang | 2 |
4 files changed, 19 insertions, 10 deletions
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index f4a34c5aa..a1fa5f21a 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -1079,6 +1079,20 @@ IRInst* getActualInstToTranscribe(IRInst* inst) return inst; } +void handleNameHint(IRBuilder* builder, IRInst* primal, IRInst* diff) +{ + // Ignore types that already have a name hint. + if (as<IRType>(diff) && diff->findDecoration<IRNameHintDecoration>()) + return; + + if (auto nameHint = primal->findDecoration<IRNameHintDecoration>()) + { + StringBuilder sb; + sb << "s_diff_" << nameHint->getName(); + builder->addNameHintDecoration(diff, sb.getUnownedSlice()); + } +} + IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst) { // If a differential instruction is already mapped for @@ -1099,7 +1113,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst instsInProgress.remove(origInst); - if (auto primalInst = pair.primal) + if (pair.primal) { mapPrimalInst(origInst, pair.primal); mapDifferentialInst(origInst, pair.differential); @@ -1124,12 +1138,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst break; default: // Generate name hint for the inst. - if (auto primalNameHint = primalInst->findDecoration<IRNameHintDecoration>()) - { - StringBuilder sb; - sb << "s_diff_" << primalNameHint->getName(); - builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice()); - } + handleNameHint(builder, pair.primal, pair.differential); // Automatically tag the primal and differential results // if they haven't already been handled by the diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 59653c4ae..9b3e3a324 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -346,7 +346,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( } if (auto originalNameHint = originalFunc->findDecoration<IRNameHintDecoration>()) { - auto primalName = String("s_bwd_primal_") + UnownedStringSlice(originalNameHint->getName()); + auto primalName = String("s_primal_ctx_") + UnownedStringSlice(originalNameHint->getName()); builder.addNameHintDecoration(primalFunc, builder.getStringValue(primalName.getUnownedSlice())); } diff --git a/tests/autodiff/reverse-checkpoint-1.slang b/tests/autodiff/reverse-checkpoint-1.slang index beb983b3b..517297013 100644 --- a/tests/autodiff/reverse-checkpoint-1.slang +++ b/tests/autodiff/reverse-checkpoint-1.slang @@ -29,7 +29,7 @@ float f(int p, float x) // Check that there are no calls to primal_g in bwd_f. // CHECK: void s_bwd_f_{{[0-9]+}} -// CHECK-NOT: {{[_a-zA-Z0-9]+}} = s_bwd_primal_g_{{[0-9]+}} +// CHECK-NOT: {{[_a-zA-Z0-9]+}} = s_primal_ctx_g_{{[0-9]+}} // CHECK: return diff --git a/tests/autodiff/reverse-checkpoint-2.slang b/tests/autodiff/reverse-checkpoint-2.slang index 68ff62176..8a7262aa4 100644 --- a/tests/autodiff/reverse-checkpoint-2.slang +++ b/tests/autodiff/reverse-checkpoint-2.slang @@ -29,7 +29,7 @@ float f(int p, float x) // Check that there are no calls to primal_g in bwd_f. // CHECK: void s_bwd_prop_f_{{[0-9]+}} -// CHECK: {{[_a-zA-Z0-9]+}} = s_bwd_primal_g_{{[0-9]+}} +// CHECK: {{[_a-zA-Z0-9]+}} = s_primal_ctx_g_{{[0-9]+}} // CHECK: return |
