From 2f422087ed04940f6b6b351605e61d48ce1989ce Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 10 Jan 2023 12:42:55 -0800 Subject: Nested bwd-diff func call context save/restore. (#2584) Co-authored-by: Yong He --- source/slang/slang-check-decl.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'source/slang/slang-check-decl.cpp') diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 7c8e320c4..b8732a67f 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -5672,7 +5672,8 @@ namespace Slang { // Requirement for backward derivative. auto declRef = DeclRef(decl, createDefaultSubstitutions(m_astBuilder, this, decl)); - auto diffFuncType = as(getBackwardDiffFuncType(getFuncType(m_astBuilder, declRef))); + auto originalFuncType = getFuncType(m_astBuilder, declRef); + auto diffFuncType = as(getBackwardDiffFuncType(originalFuncType)); { auto reqDecl = m_astBuilder->create(); cloneModifiers(reqDecl, decl); @@ -5704,8 +5705,8 @@ namespace Slang auto reqDecl = m_astBuilder->create(); cloneModifiers(reqDecl, decl); FuncType* primalFuncType = m_astBuilder->create(); - primalFuncType->resultType = diffFuncType->resultType; - primalFuncType->paramTypes.addRange(diffFuncType->paramTypes); + primalFuncType->resultType = originalFuncType->resultType; + primalFuncType->paramTypes.addRange(originalFuncType->paramTypes); auto outType = m_astBuilder->getOutType(intermediateType); primalFuncType->paramTypes.add(outType); setFuncTypeIntoRequirementDecl(reqDecl, primalFuncType); -- cgit v1.2.3