summaryrefslogtreecommitdiff
path: root/source/slang/slang-lower-to-ir.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-04-26 17:37:04 -0700
committerGitHub <noreply@github.com>2023-04-26 17:37:04 -0700
commitfc54adee1f7f0ba18591fc84ce5d51ac23afa954 (patch)
tree4727ed6109ac50e95c49aadcebc0fb8b95495739 /source/slang/slang-lower-to-ir.cpp
parent61eb17b0b556ccc06f65f921bb0a4ea2784c4e20 (diff)
Autodiff support for dynamically dispatched generic method. (#2846)
* Autodiff support for dynamically dispatched generic method. * Fix. * Support dynamically dispatched generic type. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-lower-to-ir.cpp')
-rw-r--r--source/slang/slang-lower-to-ir.cpp14
1 files changed, 12 insertions, 2 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index d644d01c7..c8a41c7c7 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -7429,7 +7429,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
else
{
- if (auto callableDecl = as<CallableDecl>(requirementDecl))
+ CallableDecl* callableDecl = nullptr;
+ if (auto genDecl = as<GenericDecl>(requirementDecl))
+ callableDecl = as<CallableDecl>(genDecl->inner);
+ else
+ callableDecl = as<CallableDecl>(requirementDecl);
+ if (callableDecl)
{
// Differentiable functions has additional requirements for the derivatives.
for (auto diffDecl : callableDecl->getMembersOfType<DerivativeRequirementReferenceDecl>())
@@ -8369,7 +8374,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
LoweredValInfo lowerFuncDeclInContext(IRGenContext* subContext, IRBuilder* subBuilder, FunctionDeclBase* decl, bool emitBody = true)
{
- auto outerGeneric = emitOuterGenerics(subContext, decl, decl);
+ IRGeneric* outerGeneric = nullptr;
+
+ if (auto derivativeRequirement = as<DerivativeRequirementDecl>(decl))
+ outerGeneric = emitOuterGenerics(subContext, derivativeRequirement->originalRequirementDecl, derivativeRequirement->originalRequirementDecl);
+ else
+ outerGeneric = emitOuterGenerics(subContext, decl, decl);
// need to create an IR function here