summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-08-23 21:45:59 -0700
committerGitHub <noreply@github.com>2024-08-23 21:45:59 -0700
commitb2ca2d5a4efeae807d3c3f48f60235e47413b559 (patch)
tree643d2bab5776e5f8f7cfa722975af9e826d77c9d /source/slang/slang-check-expr.cpp
parente4088cd602bd4d5a72fea67a787b1319acfc044d (diff)
Make variadic generics work with interfaces and forward autodiff. (#4905)
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp58
1 files changed, 58 insertions, 0 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index fe43a4f8f..4d36299bb 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1174,6 +1174,23 @@ namespace Slang
}
}
+ if (auto typePack = as<ConcreteTypePack>(type))
+ {
+ bool anyDifferentiableElement = false;
+ List<Type*> diffTypes;
+ for (Index i = 0; i < typePack->getTypeCount(); i++)
+ {
+ auto t = typePack->getElementType(i);
+ auto diffType = tryGetDifferentialType(builder, t);
+ if (!diffType)
+ diffType = m_astBuilder->getVoidType();
+ else
+ anyDifferentiableElement = true;
+ diffTypes.add(diffType);
+ }
+ if (anyDifferentiableElement)
+ return builder->getTypePack(diffTypes.getArrayView());
+ }
return nullptr;
}
@@ -1368,6 +1385,13 @@ namespace Slang
});
return;
}
+
+ if (auto typePack = as<ConcreteTypePack>(type))
+ {
+ for (Index i = 0; i < typePack->getTypeCount(); i++)
+ maybeRegisterDifferentiableTypeImplRecursive(builder, typePack->getElementType(i));
+ return;
+ }
}
@@ -2797,6 +2821,36 @@ namespace Slang
return modifiedType->getBase();
}
+ if (auto typePack = as<ConcreteTypePack>(primalType))
+ {
+ // The differential pair of a type pack should be a type pack of differential pairs.
+ List<Type*> diffTypes;
+ for (Index i = 0; i < typePack->getTypeCount(); i++)
+ {
+ auto t = typePack->getElementType(i);
+ diffTypes.add(getDifferentialPairType(t));
+ }
+ return m_astBuilder->getTypePack(diffTypes.getArrayView());
+ }
+ else if (isAbstractTypePack(primalType))
+ {
+ // The differential pair of an abstract type pack P should be `expand DifferentialPair<each P>`.
+ auto eachType = m_astBuilder->getEachType(primalType);
+ auto diffPairEachType = getDifferentialPairType(eachType);
+ if (auto expandType = as<ExpandType>(primalType))
+ {
+ List<Type*> capturedTypePacks;
+ for (Index i = 0; i < expandType->getCapturedTypePackCount(); i++)
+ {
+ capturedTypePacks.add(expandType->getCapturedTypePack(i));
+ }
+ return m_astBuilder->getExpandType(diffPairEachType, capturedTypePacks.getArrayView());
+ }
+ else
+ {
+ return m_astBuilder->getExpandType(diffPairEachType, makeArrayViewSingle(primalType));
+ }
+ }
// Get a reference to the builtin 'IDifferentiable' interface
auto differentiableInterface = getASTBuilder()->getDifferentiableInterfaceType();
@@ -3598,6 +3652,10 @@ namespace Slang
if (!isTypePack(baseType) && !as<TupleType>(baseType))
goto error;
}
+
+ if (auto tupleType = as<TupleType>(baseType))
+ baseType = tupleType->getTypePack();
+
{
SLANG_ASSERT(m_capturedTypePacks);
if (auto baseExpandType = as<ExpandType>(baseType))