From b2ca2d5a4efeae807d3c3f48f60235e47413b559 Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 23 Aug 2024 21:45:59 -0700 Subject: Make variadic generics work with interfaces and forward autodiff. (#4905) --- source/slang/slang-check-decl.cpp | 44 +++++++++++++++++++++++++++++---------- 1 file changed, 33 insertions(+), 11 deletions(-) (limited to 'source/slang/slang-check-decl.cpp') diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index c27e0c6f0..66707fc56 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -3931,7 +3931,7 @@ namespace Slang { // Our synthesized method will have parameters matching the names // and types of those on the requirement, and it will use expressions - // that reference those parametesr as arguments for the call expresison + // that reference those parameters as arguments for the call expresison // that makes up the body. // for (auto paramDeclRef : getParameters(m_astBuilder, requirement)) @@ -3951,14 +3951,6 @@ namespace Slang synParamDecl->parentDecl = synthesized; synthesized->members.add(synParamDecl); - // For each paramter, we will create an argument expression - // for the call in the function body. - // - auto synArg = m_astBuilder->create(); - synArg->declRef = makeDeclRef(synParamDecl); - synArg->type = paramType; - synArgs.add(synArg); - // Add modifiers for (auto modifier : paramDeclRef.getDecl()->modifiers) { @@ -3975,6 +3967,33 @@ namespace Slang addModifier(synParamDecl, clonedModifier); } } + + // Create an expression that references the parameter for use in arguments. + auto synArg = m_astBuilder->create(); + synArg->declRef = makeDeclRef(synParamDecl); + synArg->type = paramType; + + if (auto typePack = as(paramType)) + { + // If paramType is a concrete type pack, we want to expand it out into + // individual arguments. + for (Index i = 0; i < typePack->getTypeCount(); i++) + { + auto elementType = typePack->getElementType(i); + auto synMemberExpr = m_astBuilder->create(); + synMemberExpr->base = synArg; + synMemberExpr->elementIndices.add((UInt)i); + synMemberExpr->type = elementType; + synArgs.add(synMemberExpr); + } + } + else + { + // For ordinary non-pack paramters, we will use synArg directly to + // referencing the parameter for the call in the function body. + // + synArgs.add(synArg); + } } } @@ -4156,8 +4175,6 @@ namespace Slang addModifier(synFuncDecl, m_astBuilder->create()); synFuncDecl->parentDecl = aggTypeDecl; - SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(synFuncDecl)); - bodyVisitor.registerDifferentiableTypesForFunc(synFuncDecl); } else { @@ -4281,6 +4298,11 @@ namespace Slang // synFuncDecl->parentDecl = context->parentDecl; + // If the synthesized func is differentiable, make sure to populate its + // differential type dictionary. + SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(synFuncDecl)); + bodyVisitor.registerDifferentiableTypesForFunc(synFuncDecl); + // Once our synthesized declaration is complete, we need // to install it as the witness that satifies the given // requirement. -- cgit v1.2.3