diff options
| author | Yong He <yonghe@outlook.com> | 2023-04-26 17:37:04 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-04-26 17:37:04 -0700 |
| commit | fc54adee1f7f0ba18591fc84ce5d51ac23afa954 (patch) | |
| tree | 4727ed6109ac50e95c49aadcebc0fb8b95495739 /source/slang/slang-lower-to-ir.cpp | |
| parent | 61eb17b0b556ccc06f65f921bb0a4ea2784c4e20 (diff) | |
Autodiff support for dynamically dispatched generic method. (#2846)
* Autodiff support for dynamically dispatched generic method.
* Fix.
* Support dynamically dispatched generic type.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-lower-to-ir.cpp')
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index d644d01c7..c8a41c7c7 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -7429,7 +7429,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } else { - if (auto callableDecl = as<CallableDecl>(requirementDecl)) + CallableDecl* callableDecl = nullptr; + if (auto genDecl = as<GenericDecl>(requirementDecl)) + callableDecl = as<CallableDecl>(genDecl->inner); + else + callableDecl = as<CallableDecl>(requirementDecl); + if (callableDecl) { // Differentiable functions has additional requirements for the derivatives. for (auto diffDecl : callableDecl->getMembersOfType<DerivativeRequirementReferenceDecl>()) @@ -8369,7 +8374,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> LoweredValInfo lowerFuncDeclInContext(IRGenContext* subContext, IRBuilder* subBuilder, FunctionDeclBase* decl, bool emitBody = true) { - auto outerGeneric = emitOuterGenerics(subContext, decl, decl); + IRGeneric* outerGeneric = nullptr; + + if (auto derivativeRequirement = as<DerivativeRequirementDecl>(decl)) + outerGeneric = emitOuterGenerics(subContext, derivativeRequirement->originalRequirementDecl, derivativeRequirement->originalRequirementDecl); + else + outerGeneric = emitOuterGenerics(subContext, decl, decl); // need to create an IR function here |
