From 87f00a36a123e36b415eeea82e02a8366cc5b881 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 10 Jan 2025 03:16:24 +0530 Subject: [Auto-diff] Overhaul auto-diff type tracking + Overhaul dynamic dispatch for differentiable functions (#5866) * Overhauled the auto-diff system for dynamic dispatch * More fixes * remove intermediate dumps * Update slang-ast-type.h * More fixes + add a workaround for existential no-diff * Update reverse-control-flow-3.slang * remove dumps * remove more dumps * Delete working-reverse-control-flow-3.hlsl * Cleanup comments + unused variables * More comment cleanup * Add support for lowering `DiffPairType(TypePack)` & `MakePair(MakeValuePack, MakeValuePack)` * Fix array of issues in Falcor tests. * Update slang-ir-autodiff-pairs.cpp * More fixes for Falcor image tests * Small fixups. --------- Co-authored-by: Yong He --- source/slang/slang-ir-specialize.cpp | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) (limited to 'source/slang/slang-ir-specialize.cpp') diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 50dfa2c6a..40cd40758 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -51,15 +51,17 @@ struct SpecializationContext IRModule* module; DiagnosticSink* sink; TargetProgram* targetProgram; + SpecializationOptions options; bool changed = false; - SpecializationContext(IRModule* inModule, TargetProgram* target) + SpecializationContext(IRModule* inModule, TargetProgram* target, SpecializationOptions options) : workList(*inModule->getContainerPool().getList()) , workListSet(*inModule->getContainerPool().getHashSet()) , cleanInsts(*inModule->getContainerPool().getHashSet()) , module(inModule) , targetProgram(target) + , options(options) { } ~SpecializationContext() @@ -1102,7 +1104,11 @@ struct SpecializationContext // Now we consider lower lookupWitnessMethod insts into dynamic dispatch calls, // which may open up more specialization opportunities. // - iterChanged = lowerWitnessLookup(module, sink); + if (options.lowerWitnessLookups) + { + iterChanged = lowerWitnessLookup(module, sink); + } + if (!iterChanged || sink->getErrorCount()) break; } @@ -2882,10 +2888,14 @@ struct SpecializationContext } }; -bool specializeModule(TargetProgram* target, IRModule* module, DiagnosticSink* sink) +bool specializeModule( + TargetProgram* target, + IRModule* module, + DiagnosticSink* sink, + SpecializationOptions options) { SLANG_PROFILE; - SpecializationContext context(module, target); + SpecializationContext context(module, target, options); context.sink = sink; context.processModule(); return context.changed; -- cgit v1.2.3