summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-10 12:42:55 -0800
committerGitHub <noreply@github.com>2023-01-10 12:42:55 -0800
commit2f422087ed04940f6b6b351605e61d48ce1989ce (patch)
tree522f8027173732d903a906081238b12863d73fb8 /source/slang/slang-check-decl.cpp
parenteb813fbd8750ed1ab66d73f5fa29ae8f2407e8af (diff)
Nested bwd-diff func call context save/restore. (#2584)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp7
1 files changed, 4 insertions, 3 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 7c8e320c4..b8732a67f 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -5672,7 +5672,8 @@ namespace Slang
{
// Requirement for backward derivative.
auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl));
- auto diffFuncType = as<FuncType>(getBackwardDiffFuncType(getFuncType(m_astBuilder, declRef)));
+ auto originalFuncType = getFuncType(m_astBuilder, declRef);
+ auto diffFuncType = as<FuncType>(getBackwardDiffFuncType(originalFuncType));
{
auto reqDecl = m_astBuilder->create<BackwardDerivativeRequirementDecl>();
cloneModifiers(reqDecl, decl);
@@ -5704,8 +5705,8 @@ namespace Slang
auto reqDecl = m_astBuilder->create<BackwardDerivativePrimalRequirementDecl>();
cloneModifiers(reqDecl, decl);
FuncType* primalFuncType = m_astBuilder->create<FuncType>();
- primalFuncType->resultType = diffFuncType->resultType;
- primalFuncType->paramTypes.addRange(diffFuncType->paramTypes);
+ primalFuncType->resultType = originalFuncType->resultType;
+ primalFuncType->paramTypes.addRange(originalFuncType->paramTypes);
auto outType = m_astBuilder->getOutType(intermediateType);
primalFuncType->paramTypes.add(outType);
setFuncTypeIntoRequirementDecl(reqDecl, primalFuncType);