diff options
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 58 |
1 files changed, 58 insertions, 0 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index fe43a4f8f..4d36299bb 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1174,6 +1174,23 @@ namespace Slang } } + if (auto typePack = as<ConcreteTypePack>(type)) + { + bool anyDifferentiableElement = false; + List<Type*> diffTypes; + for (Index i = 0; i < typePack->getTypeCount(); i++) + { + auto t = typePack->getElementType(i); + auto diffType = tryGetDifferentialType(builder, t); + if (!diffType) + diffType = m_astBuilder->getVoidType(); + else + anyDifferentiableElement = true; + diffTypes.add(diffType); + } + if (anyDifferentiableElement) + return builder->getTypePack(diffTypes.getArrayView()); + } return nullptr; } @@ -1368,6 +1385,13 @@ namespace Slang }); return; } + + if (auto typePack = as<ConcreteTypePack>(type)) + { + for (Index i = 0; i < typePack->getTypeCount(); i++) + maybeRegisterDifferentiableTypeImplRecursive(builder, typePack->getElementType(i)); + return; + } } @@ -2797,6 +2821,36 @@ namespace Slang return modifiedType->getBase(); } + if (auto typePack = as<ConcreteTypePack>(primalType)) + { + // The differential pair of a type pack should be a type pack of differential pairs. + List<Type*> diffTypes; + for (Index i = 0; i < typePack->getTypeCount(); i++) + { + auto t = typePack->getElementType(i); + diffTypes.add(getDifferentialPairType(t)); + } + return m_astBuilder->getTypePack(diffTypes.getArrayView()); + } + else if (isAbstractTypePack(primalType)) + { + // The differential pair of an abstract type pack P should be `expand DifferentialPair<each P>`. + auto eachType = m_astBuilder->getEachType(primalType); + auto diffPairEachType = getDifferentialPairType(eachType); + if (auto expandType = as<ExpandType>(primalType)) + { + List<Type*> capturedTypePacks; + for (Index i = 0; i < expandType->getCapturedTypePackCount(); i++) + { + capturedTypePacks.add(expandType->getCapturedTypePack(i)); + } + return m_astBuilder->getExpandType(diffPairEachType, capturedTypePacks.getArrayView()); + } + else + { + return m_astBuilder->getExpandType(diffPairEachType, makeArrayViewSingle(primalType)); + } + } // Get a reference to the builtin 'IDifferentiable' interface auto differentiableInterface = getASTBuilder()->getDifferentiableInterfaceType(); @@ -3598,6 +3652,10 @@ namespace Slang if (!isTypePack(baseType) && !as<TupleType>(baseType)) goto error; } + + if (auto tupleType = as<TupleType>(baseType)) + baseType = tupleType->getTypePack(); + { SLANG_ASSERT(m_capturedTypePacks); if (auto baseExpandType = as<ExpandType>(baseType)) |
