From 87f00a36a123e36b415eeea82e02a8366cc5b881 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 10 Jan 2025 03:16:24 +0530 Subject: [Auto-diff] Overhaul auto-diff type tracking + Overhaul dynamic dispatch for differentiable functions (#5866) * Overhauled the auto-diff system for dynamic dispatch * More fixes * remove intermediate dumps * Update slang-ast-type.h * More fixes + add a workaround for existential no-diff * Update reverse-control-flow-3.slang * remove dumps * remove more dumps * Delete working-reverse-control-flow-3.hlsl * Cleanup comments + unused variables * More comment cleanup * Add support for lowering `DiffPairType(TypePack)` & `MakePair(MakeValuePack, MakeValuePack)` * Fix array of issues in Falcor tests. * Update slang-ir-autodiff-pairs.cpp * More fixes for Falcor image tests * Small fixups. --------- Co-authored-by: Yong He --- source/slang/slang-emit.cpp | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) (limited to 'source/slang/slang-emit.cpp') diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index b9217de41..cd1b177b2 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -815,7 +815,18 @@ Result linkAndOptimizeIR( bool changed = false; dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-SPECIALIZE"); if (!codeGenContext->isSpecializationDisabled()) - changed |= specializeModule(targetProgram, irModule, codeGenContext->getSink()); + { + // Pre-autodiff, we will attempt to specialize as much as possible. + // + // Note: Lowered dynamic-dispatch code cannot be differentiated correctly due to + // missing information, so we defer that to after the auto-dff step. + // + SpecializationOptions specOptions; + specOptions.lowerWitnessLookups = false; + changed |= + specializeModule(targetProgram, irModule, codeGenContext->getSink(), specOptions); + } + if (codeGenContext->getSink()->getErrorCount() != 0) return SLANG_FAIL; dumpIRIfEnabled(codeGenContext, irModule, "AFTER-SPECIALIZE"); @@ -867,9 +878,20 @@ Result linkAndOptimizeIR( reportCheckpointIntermediates(codeGenContext, sink, irModule); // Finalization is always run so AD-related instructions can be removed, - // even the AD pass itself is not run. + // even if the AD pass itself is not run. // finalizeAutoDiffPass(targetProgram, irModule); + eliminateDeadCode(irModule, deadCodeEliminationOptions); + + // After auto-diff, we can perform more aggressive specialization with dynamic-dispatch + // lowering. + // + if (!codeGenContext->isSpecializationDisabled()) + { + SpecializationOptions specOptions; + specOptions.lowerWitnessLookups = true; + specializeModule(targetProgram, irModule, codeGenContext->getSink(), specOptions); + } finalizeSpecialization(irModule); @@ -930,6 +952,8 @@ Result linkAndOptimizeIR( validateIRModuleIfEnabled(codeGenContext, irModule); + inferAnyValueSizeWhereNecessary(targetProgram, irModule); + // If we have any witness tables that are marked as `KeepAlive`, // but are not used for dynamic dispatch, unpin them so we don't // do unnecessary work to lower them. -- cgit v1.2.3