diff options
| author | Yong He <yonghe@outlook.com> | 2023-01-10 12:42:55 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-10 12:42:55 -0800 |
| commit | 2f422087ed04940f6b6b351605e61d48ce1989ce (patch) | |
| tree | 522f8027173732d903a906081238b12863d73fb8 /source/slang/slang-check-decl.cpp | |
| parent | eb813fbd8750ed1ab66d73f5fa29ae8f2407e8af (diff) | |
Nested bwd-diff func call context save/restore. (#2584)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 7c8e320c4..b8732a67f 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -5672,7 +5672,8 @@ namespace Slang { // Requirement for backward derivative. auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl)); - auto diffFuncType = as<FuncType>(getBackwardDiffFuncType(getFuncType(m_astBuilder, declRef))); + auto originalFuncType = getFuncType(m_astBuilder, declRef); + auto diffFuncType = as<FuncType>(getBackwardDiffFuncType(originalFuncType)); { auto reqDecl = m_astBuilder->create<BackwardDerivativeRequirementDecl>(); cloneModifiers(reqDecl, decl); @@ -5704,8 +5705,8 @@ namespace Slang auto reqDecl = m_astBuilder->create<BackwardDerivativePrimalRequirementDecl>(); cloneModifiers(reqDecl, decl); FuncType* primalFuncType = m_astBuilder->create<FuncType>(); - primalFuncType->resultType = diffFuncType->resultType; - primalFuncType->paramTypes.addRange(diffFuncType->paramTypes); + primalFuncType->resultType = originalFuncType->resultType; + primalFuncType->paramTypes.addRange(originalFuncType->paramTypes); auto outType = m_astBuilder->getOutType(intermediateType); primalFuncType->paramTypes.add(outType); setFuncTypeIntoRequirementDecl(reqDecl, primalFuncType); |
