summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-08-23 21:45:59 -0700
committerGitHub <noreply@github.com>2024-08-23 21:45:59 -0700
commitb2ca2d5a4efeae807d3c3f48f60235e47413b559 (patch)
tree643d2bab5776e5f8f7cfa722975af9e826d77c9d /source/slang/slang-check-decl.cpp
parente4088cd602bd4d5a72fea67a787b1319acfc044d (diff)
Make variadic generics work with interfaces and forward autodiff. (#4905)
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp44
1 files changed, 33 insertions, 11 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index c27e0c6f0..66707fc56 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -3931,7 +3931,7 @@ namespace Slang
{
// Our synthesized method will have parameters matching the names
// and types of those on the requirement, and it will use expressions
- // that reference those parametesr as arguments for the call expresison
+ // that reference those parameters as arguments for the call expresison
// that makes up the body.
//
for (auto paramDeclRef : getParameters(m_astBuilder, requirement))
@@ -3951,14 +3951,6 @@ namespace Slang
synParamDecl->parentDecl = synthesized;
synthesized->members.add(synParamDecl);
- // For each paramter, we will create an argument expression
- // for the call in the function body.
- //
- auto synArg = m_astBuilder->create<VarExpr>();
- synArg->declRef = makeDeclRef(synParamDecl);
- synArg->type = paramType;
- synArgs.add(synArg);
-
// Add modifiers
for (auto modifier : paramDeclRef.getDecl()->modifiers)
{
@@ -3975,6 +3967,33 @@ namespace Slang
addModifier(synParamDecl, clonedModifier);
}
}
+
+ // Create an expression that references the parameter for use in arguments.
+ auto synArg = m_astBuilder->create<VarExpr>();
+ synArg->declRef = makeDeclRef(synParamDecl);
+ synArg->type = paramType;
+
+ if (auto typePack = as<ConcreteTypePack>(paramType))
+ {
+ // If paramType is a concrete type pack, we want to expand it out into
+ // individual arguments.
+ for (Index i = 0; i < typePack->getTypeCount(); i++)
+ {
+ auto elementType = typePack->getElementType(i);
+ auto synMemberExpr = m_astBuilder->create<SwizzleExpr>();
+ synMemberExpr->base = synArg;
+ synMemberExpr->elementIndices.add((UInt)i);
+ synMemberExpr->type = elementType;
+ synArgs.add(synMemberExpr);
+ }
+ }
+ else
+ {
+ // For ordinary non-pack paramters, we will use synArg directly to
+ // referencing the parameter for the call in the function body.
+ //
+ synArgs.add(synArg);
+ }
}
}
@@ -4156,8 +4175,6 @@ namespace Slang
addModifier(synFuncDecl, m_astBuilder->create<ForceInlineAttribute>());
synFuncDecl->parentDecl = aggTypeDecl;
- SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(synFuncDecl));
- bodyVisitor.registerDifferentiableTypesForFunc(synFuncDecl);
}
else
{
@@ -4281,6 +4298,11 @@ namespace Slang
//
synFuncDecl->parentDecl = context->parentDecl;
+ // If the synthesized func is differentiable, make sure to populate its
+ // differential type dictionary.
+ SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(synFuncDecl));
+ bodyVisitor.registerDifferentiableTypesForFunc(synFuncDecl);
+
// Once our synthesized declaration is complete, we need
// to install it as the witness that satifies the given
// requirement.