From b2ca2d5a4efeae807d3c3f48f60235e47413b559 Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 23 Aug 2024 21:45:59 -0700 Subject: Make variadic generics work with interfaces and forward autodiff. (#4905) --- source/slang/slang-check-expr.cpp | 58 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) (limited to 'source/slang/slang-check-expr.cpp') diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index fe43a4f8f..4d36299bb 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1174,6 +1174,23 @@ namespace Slang } } + if (auto typePack = as(type)) + { + bool anyDifferentiableElement = false; + List diffTypes; + for (Index i = 0; i < typePack->getTypeCount(); i++) + { + auto t = typePack->getElementType(i); + auto diffType = tryGetDifferentialType(builder, t); + if (!diffType) + diffType = m_astBuilder->getVoidType(); + else + anyDifferentiableElement = true; + diffTypes.add(diffType); + } + if (anyDifferentiableElement) + return builder->getTypePack(diffTypes.getArrayView()); + } return nullptr; } @@ -1368,6 +1385,13 @@ namespace Slang }); return; } + + if (auto typePack = as(type)) + { + for (Index i = 0; i < typePack->getTypeCount(); i++) + maybeRegisterDifferentiableTypeImplRecursive(builder, typePack->getElementType(i)); + return; + } } @@ -2797,6 +2821,36 @@ namespace Slang return modifiedType->getBase(); } + if (auto typePack = as(primalType)) + { + // The differential pair of a type pack should be a type pack of differential pairs. + List diffTypes; + for (Index i = 0; i < typePack->getTypeCount(); i++) + { + auto t = typePack->getElementType(i); + diffTypes.add(getDifferentialPairType(t)); + } + return m_astBuilder->getTypePack(diffTypes.getArrayView()); + } + else if (isAbstractTypePack(primalType)) + { + // The differential pair of an abstract type pack P should be `expand DifferentialPair`. + auto eachType = m_astBuilder->getEachType(primalType); + auto diffPairEachType = getDifferentialPairType(eachType); + if (auto expandType = as(primalType)) + { + List capturedTypePacks; + for (Index i = 0; i < expandType->getCapturedTypePackCount(); i++) + { + capturedTypePacks.add(expandType->getCapturedTypePack(i)); + } + return m_astBuilder->getExpandType(diffPairEachType, capturedTypePacks.getArrayView()); + } + else + { + return m_astBuilder->getExpandType(diffPairEachType, makeArrayViewSingle(primalType)); + } + } // Get a reference to the builtin 'IDifferentiable' interface auto differentiableInterface = getASTBuilder()->getDifferentiableInterfaceType(); @@ -3598,6 +3652,10 @@ namespace Slang if (!isTypePack(baseType) && !as(baseType)) goto error; } + + if (auto tupleType = as(baseType)) + baseType = tupleType->getTypePack(); + { SLANG_ASSERT(m_capturedTypePacks); if (auto baseExpandType = as(baseType)) -- cgit v1.2.3