From 87f00a36a123e36b415eeea82e02a8366cc5b881 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 10 Jan 2025 03:16:24 +0530 Subject: [Auto-diff] Overhaul auto-diff type tracking + Overhaul dynamic dispatch for differentiable functions (#5866) * Overhauled the auto-diff system for dynamic dispatch * More fixes * remove intermediate dumps * Update slang-ast-type.h * More fixes + add a workaround for existential no-diff * Update reverse-control-flow-3.slang * remove dumps * remove more dumps * Delete working-reverse-control-flow-3.hlsl * Cleanup comments + unused variables * More comment cleanup * Add support for lowering `DiffPairType(TypePack)` & `MakePair(MakeValuePack, MakeValuePack)` * Fix array of issues in Falcor tests. * Update slang-ir-autodiff-pairs.cpp * More fixes for Falcor image tests * Small fixups. --------- Co-authored-by: Yong He --- source/slang/slang-check-decl.cpp | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) (limited to 'source/slang/slang-check-decl.cpp') diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 3667a36ba..04d5b7a75 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -9247,12 +9247,16 @@ void SemanticsDeclHeaderVisitor::checkDifferentiableCallableCommon(CallableDecl* if (!decl->hasModifier()) { // Build decl-ref-type from interface. - auto interfaceType = - DeclRefType::create(getASTBuilder(), makeDeclRef(interfaceDecl)); + auto thisType = DeclRefType::create( + m_astBuilder, + createDefaultSubstitutionsIfNeeded( + m_astBuilder, + this, + makeDeclRef(interfaceDecl->getThisTypeDecl()))); // If the interface is differentiable, make the this type a pair. - if (tryGetDifferentialType(getASTBuilder(), interfaceType)) - reqDecl->diffThisType = getDifferentialPairType(interfaceType); + if (tryGetDifferentialType(getASTBuilder(), thisType)) + reqDecl->diffThisType = getDifferentialPairType(thisType); } auto reqRef = m_astBuilder->create(); @@ -9277,13 +9281,17 @@ void SemanticsDeclHeaderVisitor::checkDifferentiableCallableCommon(CallableDecl* reqDecl->parentDecl = interfaceDecl; if (!decl->hasModifier()) { - // Build decl-ref-type from interface. - auto interfaceType = - DeclRefType::create(getASTBuilder(), makeDeclRef(interfaceDecl)); + // Build decl-ref-type for this-type. + auto thisType = DeclRefType::create( + m_astBuilder, + createDefaultSubstitutionsIfNeeded( + m_astBuilder, + this, + makeDeclRef(interfaceDecl->getThisTypeDecl()))); // If the interface is differentiable, make the this type a pair. - if (tryGetDifferentialType(getASTBuilder(), interfaceType)) - reqDecl->diffThisType = getDifferentialPairType(interfaceType); + if (tryGetDifferentialType(getASTBuilder(), thisType)) + reqDecl->diffThisType = getDifferentialPairType(thisType); } auto reqRef = m_astBuilder->create(); -- cgit v1.2.3