summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-08-23 21:45:59 -0700
committerGitHub <noreply@github.com>2024-08-23 21:45:59 -0700
commitb2ca2d5a4efeae807d3c3f48f60235e47413b559 (patch)
tree643d2bab5776e5f8f7cfa722975af9e826d77c9d
parente4088cd602bd4d5a72fea67a787b1319acfc044d (diff)
Make variadic generics work with interfaces and forward autodiff. (#4905)
-rw-r--r--source/slang/core.meta.slang24
-rw-r--r--source/slang/slang-check-decl.cpp44
-rw-r--r--source/slang/slang-check-expr.cpp58
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp43
-rw-r--r--source/slang/slang-ir-autodiff-fwd.h4
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp191
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.h4
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h87
-rw-r--r--source/slang/slang-ir-autodiff.cpp156
-rw-r--r--source/slang/slang-ir-autodiff.h7
-rw-r--r--source/slang/slang-ir-inst-defs.h6
-rw-r--r--source/slang/slang-ir-lower-expand-type.cpp167
-rw-r--r--source/slang/slang-ir-lower-expand-type.h30
-rw-r--r--source/slang/slang-ir-peephole.cpp1
-rw-r--r--source/slang/slang-ir-specialize.cpp352
-rw-r--r--source/slang/slang-ir-util.cpp1
-rw-r--r--source/slang/slang-ir.cpp12
-rw-r--r--source/slang/slang-ir.h10
-rw-r--r--source/slang/slang-lower-to-ir.cpp10
-rw-r--r--source/slang/slang-mangle.cpp42
-rw-r--r--tests/language-feature/generics/variadic-0.slang4
-rw-r--r--tests/language-feature/generics/variadic-void.slang2
-rw-r--r--tests/language-feature/ifunc/diff-functor.slang44
-rw-r--r--tests/language-feature/ifunc/ifunc.slang40
24 files changed, 967 insertions, 372 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 84e1b8168..0b57993ef 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -953,6 +953,30 @@ extension Tuple<T> : IComparable
}
}
+interface IMutatingFunc<TR, each TP>
+{
+ [mutating]
+ TR __call(expand each TP p);
+}
+
+interface IFunc<TR, each TP> : IMutatingFunc<TR, expand each TP>
+{
+ TR __call(expand each TP p);
+}
+
+interface IDifferentiableMutatingFunc<TR : IDifferentiable, each TP : IDifferentiable> : IMutatingFunc<TR, expand each TP>
+{
+ [Differentiable]
+ [mutating]
+ TR __call(expand each TP p);
+}
+
+interface IDifferentiableFunc<TR : IDifferentiable, each TP : IDifferentiable> : IFunc<TR, expand each TP>, IDifferentiableMutatingFunc<TR, expand each TP>
+{
+ [Differentiable]
+ TR __call(expand each TP p);
+}
+
__generic<T>
__magic_type(NativeRefType)
__intrinsic_type($(kIROp_NativePtrType))
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index c27e0c6f0..66707fc56 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -3931,7 +3931,7 @@ namespace Slang
{
// Our synthesized method will have parameters matching the names
// and types of those on the requirement, and it will use expressions
- // that reference those parametesr as arguments for the call expresison
+ // that reference those parameters as arguments for the call expression
// that makes up the body.
//
for (auto paramDeclRef : getParameters(m_astBuilder, requirement))
@@ -3951,14 +3951,6 @@ namespace Slang
synParamDecl->parentDecl = synthesized;
synthesized->members.add(synParamDecl);
- // For each paramter, we will create an argument expression
- // for the call in the function body.
- //
- auto synArg = m_astBuilder->create<VarExpr>();
- synArg->declRef = makeDeclRef(synParamDecl);
- synArg->type = paramType;
- synArgs.add(synArg);
-
// Add modifiers
for (auto modifier : paramDeclRef.getDecl()->modifiers)
{
@@ -3975,6 +3967,33 @@ namespace Slang
addModifier(synParamDecl, clonedModifier);
}
}
+
+ // Create an expression that references the parameter for use in arguments.
+ auto synArg = m_astBuilder->create<VarExpr>();
+ synArg->declRef = makeDeclRef(synParamDecl);
+ synArg->type = paramType;
+
+ if (auto typePack = as<ConcreteTypePack>(paramType))
+ {
+ // If paramType is a concrete type pack, we want to expand it out into
+ // individual arguments.
+ for (Index i = 0; i < typePack->getTypeCount(); i++)
+ {
+ auto elementType = typePack->getElementType(i);
+ auto synMemberExpr = m_astBuilder->create<SwizzleExpr>();
+ synMemberExpr->base = synArg;
+ synMemberExpr->elementIndices.add((UInt)i);
+ synMemberExpr->type = elementType;
+ synArgs.add(synMemberExpr);
+ }
+ }
+ else
+ {
+ // For ordinary non-pack parameters, we will use synArg directly to
+ // reference the parameter for the call in the function body.
+ //
+ synArgs.add(synArg);
+ }
}
}
@@ -4156,8 +4175,6 @@ namespace Slang
addModifier(synFuncDecl, m_astBuilder->create<ForceInlineAttribute>());
synFuncDecl->parentDecl = aggTypeDecl;
- SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(synFuncDecl));
- bodyVisitor.registerDifferentiableTypesForFunc(synFuncDecl);
}
else
{
@@ -4281,6 +4298,11 @@ namespace Slang
//
synFuncDecl->parentDecl = context->parentDecl;
+ // If the synthesized func is differentiable, make sure to populate its
+ // differential type dictionary.
+ SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(synFuncDecl));
+ bodyVisitor.registerDifferentiableTypesForFunc(synFuncDecl);
+
// Once our synthesized declaration is complete, we need
// to install it as the witness that satifies the given
// requirement.
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index fe43a4f8f..4d36299bb 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1174,6 +1174,23 @@ namespace Slang
}
}
+ if (auto typePack = as<ConcreteTypePack>(type))
+ {
+ bool anyDifferentiableElement = false;
+ List<Type*> diffTypes;
+ for (Index i = 0; i < typePack->getTypeCount(); i++)
+ {
+ auto t = typePack->getElementType(i);
+ auto diffType = tryGetDifferentialType(builder, t);
+ if (!diffType)
+ diffType = m_astBuilder->getVoidType();
+ else
+ anyDifferentiableElement = true;
+ diffTypes.add(diffType);
+ }
+ if (anyDifferentiableElement)
+ return builder->getTypePack(diffTypes.getArrayView());
+ }
return nullptr;
}
@@ -1368,6 +1385,13 @@ namespace Slang
});
return;
}
+
+ if (auto typePack = as<ConcreteTypePack>(type))
+ {
+ for (Index i = 0; i < typePack->getTypeCount(); i++)
+ maybeRegisterDifferentiableTypeImplRecursive(builder, typePack->getElementType(i));
+ return;
+ }
}
@@ -2797,6 +2821,36 @@ namespace Slang
return modifiedType->getBase();
}
+ if (auto typePack = as<ConcreteTypePack>(primalType))
+ {
+ // The differential pair of a type pack should be a type pack of differential pairs.
+ List<Type*> diffTypes;
+ for (Index i = 0; i < typePack->getTypeCount(); i++)
+ {
+ auto t = typePack->getElementType(i);
+ diffTypes.add(getDifferentialPairType(t));
+ }
+ return m_astBuilder->getTypePack(diffTypes.getArrayView());
+ }
+ else if (isAbstractTypePack(primalType))
+ {
+ // The differential pair of an abstract type pack P should be `expand DifferentialPair<each P>`.
+ auto eachType = m_astBuilder->getEachType(primalType);
+ auto diffPairEachType = getDifferentialPairType(eachType);
+ if (auto expandType = as<ExpandType>(primalType))
+ {
+ List<Type*> capturedTypePacks;
+ for (Index i = 0; i < expandType->getCapturedTypePackCount(); i++)
+ {
+ capturedTypePacks.add(expandType->getCapturedTypePack(i));
+ }
+ return m_astBuilder->getExpandType(diffPairEachType, capturedTypePacks.getArrayView());
+ }
+ else
+ {
+ return m_astBuilder->getExpandType(diffPairEachType, makeArrayViewSingle(primalType));
+ }
+ }
// Get a reference to the builtin 'IDifferentiable' interface
auto differentiableInterface = getASTBuilder()->getDifferentiableInterfaceType();
@@ -3598,6 +3652,10 @@ namespace Slang
if (!isTypePack(baseType) && !as<TupleType>(baseType))
goto error;
}
+
+ if (auto tupleType = as<TupleType>(baseType))
+ baseType = tupleType->getTypePack();
+
{
SLANG_ASSERT(m_capturedTypePacks);
if (auto baseExpandType = as<ExpandType>(baseType))
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 9adbe42d5..91d3e71cb 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -444,6 +444,10 @@ InstPair ForwardDiffTranscriber::transcribeConstruct(IRBuilder* builder, IRInst*
operandDataType = (IRType*)findOrTranscribePrimalInst(builder, operandDataType);
diffOperands.add(getDifferentialZeroOfType(builder, operandDataType));
}
+ else
+ {
+ diffOperands.add(builder->getVoidValue());
+ }
}
}
@@ -1110,6 +1114,39 @@ InstPair ForwardDiffTranscriber::transcribeGetElement(IRBuilder* builder, IRInst
return InstPair(primalGetElementPtr, diffGetElementPtr);
}
+InstPair ForwardDiffTranscriber::transcribeGetTupleElement(IRBuilder* builder, IRInst* originalInst)
+{
+ IRInst* origBase = originalInst->getOperand(0);
+ auto primalBase = findOrTranscribePrimalInst(builder, origBase);
+ auto primalIndex = originalInst->getOperand(1);
+
+ auto primalType = (IRType*)findOrTranscribePrimalInst(builder, originalInst->getDataType());
+
+ IRInst* primalOperands[] = { primalBase, primalIndex };
+ IRInst* primalGetElement = builder->emitIntrinsicInst(
+ primalType,
+ originalInst->getOp(),
+ 2,
+ primalOperands);
+
+ IRInst* diffGetElement = nullptr;
+
+ if (auto diffType = differentiateType(builder, primalGetElement->getDataType()))
+ {
+ if (auto diffBase = findOrTranscribeDiffInst(builder, origBase))
+ {
+ IRInst* diffOperands[] = { diffBase, primalIndex };
+ diffGetElement = builder->emitIntrinsicInst(
+ diffType,
+ originalInst->getOp(),
+ 2,
+ diffOperands);
+ }
+ }
+
+ return InstPair(primalGetElement, diffGetElement);
+}
+
InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRInst* originalInst)
{
auto updateInst = as<IRUpdateElement>(originalInst);
@@ -1792,6 +1829,8 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_MakeVectorFromScalar:
case kIROp_MakeArray:
case kIROp_MakeArrayFromElement:
+ case kIROp_MakeTuple:
+ case kIROp_MakeValuePack:
return transcribeConstruct(builder, origInst);
case kIROp_MakeStruct:
return transcribeMakeStruct(builder, origInst);
@@ -1805,7 +1844,6 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_swizzle:
return transcribeSwizzle(builder, as<IRSwizzle>(origInst));
- case kIROp_MakeTuple:
case kIROp_Neg:
return transcribeByPassthrough(builder, origInst);
@@ -1832,6 +1870,9 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_GetElementPtr:
return transcribeGetElement(builder, origInst);
+ case kIROp_GetTupleElement:
+ return transcribeGetTupleElement(builder, origInst);
+
case kIROp_ifElse:
return transcribeIfElse(builder, as<IRIfElse>(origInst));
diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h
index f88235558..f2659777d 100644
--- a/source/slang/slang-ir-autodiff-fwd.h
+++ b/source/slang/slang-ir-autodiff-fwd.h
@@ -48,6 +48,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase
InstPair transcribeConstruct(IRBuilder* builder, IRInst* origConstruct);
InstPair transcribeMakeStruct(IRBuilder* builder, IRInst* origMakeStruct);
+ InstPair transcribeMakeTuple(IRBuilder* builder, IRInst* origMakeTuple);
+
// Differentiating a call instruction here is primarily about generating
// an appropriate call list based on whichever parameters have differentials
// in the current transcription context.
@@ -68,6 +70,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase
InstPair transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr);
+ InstPair transcribeGetTupleElement(IRBuilder* builder, IRInst* origInst);
+
InstPair transcribeUpdateElement(IRBuilder* builder, IRInst* originalInst);
InstPair transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse);
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index a1fa5f21a..da69ed8ae 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -174,179 +174,9 @@ IRInst* AutoDiffTranscriberBase::maybeCloneForPrimalInst(IRBuilder* builder, IRI
IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey);
-// Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
-IRWitnessTable* AutoDiffTranscriberBase::getDifferentialPairWitness(IRBuilder* builder, IRInst* inOriginalDiffPairType, IRInst* inPrimalDiffPairType)
-{
- // Differentiate the pair type to get it's differential (which is itself a pair)
- auto diffDiffPairType = (IRType*)differentiateType(builder, (IRType*)inOriginalDiffPairType);
-
- auto addMethod = builder->createFunc();
- auto zeroMethod = builder->createFunc();
-
- auto table = builder->createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, (IRType*)inPrimalDiffPairType);
-
- // And place it in the synthesized witness table.
- builder->createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, diffDiffPairType);
- builder->createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeWitnessStructKey, table);
- builder->createWitnessTableEntry(table, autoDiffSharedContext->addMethodStructKey, addMethod);
- builder->createWitnessTableEntry(table, autoDiffSharedContext->zeroMethodStructKey, zeroMethod);
-
- bool isUserCodeType = as<IRDifferentialPairUserCodeType>(inOriginalDiffPairType) ? true : false;
-
- // Fill in differential method implementations.
- auto elementType = as<IRDifferentialPairTypeBase>(inPrimalDiffPairType)->getValueType();
- auto innerWitness = as<IRDifferentialPairTypeBase>(inPrimalDiffPairType)->getWitness();
-
- {
- // Add method.
- IRBuilder b = *builder;
- b.setInsertInto(addMethod);
- b.addBackwardDifferentiableDecoration(addMethod);
- IRType* paramTypes[2] = { diffDiffPairType, diffDiffPairType };
- addMethod->setFullType(b.getFuncType(2, paramTypes, diffDiffPairType));
- b.emitBlock();
- auto p0 = b.emitParam(diffDiffPairType);
- auto p1 = b.emitParam(diffDiffPairType);
-
- // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value type == diff type.
- auto innerAdd = _lookupWitness(&b, innerWitness, autoDiffSharedContext->addMethodStructKey);
- IRInst* argsPrimal[2] = {
- isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p0) : b.emitDifferentialPairGetPrimal(p0),
- isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p1) : b.emitDifferentialPairGetPrimal(p1) };
- auto primalPart = b.emitCallInst(elementType, innerAdd, 2, argsPrimal);
- IRInst* argsDiff[2] = {
- isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p0) : b.emitDifferentialPairGetDifferential(elementType, p0),
- isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p1) : b.emitDifferentialPairGetDifferential(elementType, p1)};
- auto diffPart = b.emitCallInst(elementType, innerAdd, 2, argsDiff);
- auto retVal =
- isUserCodeType
- ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, primalPart, diffPart)
- : b.emitMakeDifferentialPair(diffDiffPairType, primalPart, diffPart);
- b.emitReturn(retVal);
- }
- {
- // Zero method.
- IRBuilder b = *builder;
- b.setInsertInto(zeroMethod);
- zeroMethod->setFullType(b.getFuncType(0, nullptr, diffDiffPairType));
- b.emitBlock();
- auto innerZero = _lookupWitness(&b, innerWitness, autoDiffSharedContext->zeroMethodStructKey);
- auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr);
- auto retVal =
- isUserCodeType
- ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, zeroVal, zeroVal)
- : b.emitMakeDifferentialPair(diffDiffPairType, zeroVal, zeroVal);
- b.emitReturn(retVal);
- }
-
- // Record this in the context for future lookups
- differentiableTypeConformanceContext.differentiableWitnessDictionary[(IRType*)inOriginalDiffPairType] = table;
-
- return table;
-}
-
-// Get or construct `:IDifferentiable` conformance for an Array.
-IRWitnessTable* AutoDiffTranscriberBase::getArrayWitness(IRBuilder* builder, IRInst* inOriginalArrayType, IRInst* inPrimalArrayType)
-{
- // Differentiate the pair type to get it's differential (which is itself a pair)
- auto diffArrayType = (IRType*)differentiateType(builder, (IRType*)inOriginalArrayType);
-
- if (!diffArrayType)
- return nullptr;
-
- auto innerWitness = tryGetDifferentiableWitness(builder, as<IRArrayTypeBase>(inOriginalArrayType)->getElementType());
-
- auto addMethod = builder->createFunc();
- auto zeroMethod = builder->createFunc();
-
- auto table = builder->createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, (IRType*)inPrimalArrayType);
-
- // And place it in the synthesized witness table.
- builder->createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, diffArrayType);
- builder->createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeWitnessStructKey, table);
- builder->createWitnessTableEntry(table, autoDiffSharedContext->addMethodStructKey, addMethod);
- builder->createWitnessTableEntry(table, autoDiffSharedContext->zeroMethodStructKey, zeroMethod);
-
- auto elementType = as<IRArrayTypeBase>(diffArrayType)->getElementType();
-
- // Fill in differential method implementations.
- {
- // Add method.
- IRBuilder b = *builder;
- b.setInsertInto(addMethod);
- b.addBackwardDifferentiableDecoration(addMethod);
- IRType* paramTypes[2] = { diffArrayType, diffArrayType };
- addMethod->setFullType(b.getFuncType(2, paramTypes, diffArrayType));
- b.emitBlock();
- auto p0 = b.emitParam(diffArrayType);
- auto p1 = b.emitParam(diffArrayType);
-
- // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value type == diff type.
- auto innerAdd = _lookupWitness(&b, innerWitness, autoDiffSharedContext->addMethodStructKey);
- auto resultVar = b.emitVar(diffArrayType);
- IRBlock* loopBodyBlock = nullptr;
- IRBlock* loopBreakBlock = nullptr;
- auto loopCounter = emitLoopBlocks(&b, b.getIntValue(b.getIntType(), 0), as<IRArrayTypeBase>(diffArrayType)->getElementCount(), loopBodyBlock, loopBreakBlock);
- b.setInsertBefore(loopBodyBlock->getTerminator());
-
- IRInst* args[2] = {
- b.emitElementExtract(p0, loopCounter),
- b.emitElementExtract(p1, loopCounter) };
- auto elementResult = b.emitCallInst(elementType, innerAdd, 2, args);
- auto addr = b.emitElementAddress(resultVar, loopCounter);
- b.emitStore(addr, elementResult);
- b.setInsertInto(loopBreakBlock);
- b.emitReturn(b.emitLoad(resultVar));
- }
- {
- // Zero method.
- IRBuilder b = *builder;
- b.setInsertInto(zeroMethod);
- zeroMethod->setFullType(b.getFuncType(0, nullptr, diffArrayType));
- b.emitBlock();
-
- auto innerZero = _lookupWitness(&b, innerWitness, autoDiffSharedContext->zeroMethodStructKey);
- auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr);
- auto retVal = b.emitMakeArrayFromElement(diffArrayType, zeroVal);
- b.emitReturn(retVal);
- }
-
- // Record this in the context for future lookups
- differentiableTypeConformanceContext.differentiableWitnessDictionary[(IRType*)inOriginalArrayType] = table;
-
- return table;
-}
-
IRInst* AutoDiffTranscriberBase::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType)
{
- if (isNoDiffType((IRType*)originalType))
- return nullptr;
-
- IRInst* witness =
- differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)originalType);
- if (witness)
- {
- witness = lookupPrimalInst(builder, witness, nullptr);
- SLANG_RELEASE_ASSERT(witness || as<IRArrayType>(originalType));
- }
- if (!witness)
- {
- auto primalType = lookupPrimalInst(builder, originalType, nullptr);
- SLANG_RELEASE_ASSERT(primalType);
- if (auto primalPairType = as<IRDifferentialPairTypeBase>(primalType))
- {
- witness = getDifferentialPairWitness(builder, originalType, primalPairType);
- }
- else if (auto arrayType = as<IRArrayType>(primalType))
- {
- witness = getArrayWitness(builder, originalType, arrayType);
- }
- else if (auto extractExistential = as<IRExtractExistentialType>(originalType))
- {
- differentiateExtractExistentialType(builder, extractExistential, witness);
- }
- }
- return witness;
+ return differentiableTypeConformanceContext.tryGetDifferentiableWitness(builder, originalType);
}
IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness)
@@ -486,15 +316,20 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy
}
case kIROp_TupleType:
+ case kIROp_TypePack:
{
- auto tupleType = as<IRTupleType>(primalType);
List<IRType*> diffTypeList;
- // TODO: what if we have type parameters here?
- for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++)
- diffTypeList.add(
- differentiateType(builder, (IRType*)tupleType->getOperand(ii)));
-
- return builder->getTupleType(diffTypeList);
+ for (UIndex ii = 0; ii < primalType->getOperandCount(); ii++)
+ {
+ auto diffElementType = differentiateType(builder, (IRType*)primalType->getOperand(ii));
+ if (!diffElementType)
+ diffElementType = builder->getVoidType();
+ diffTypeList.add(diffElementType);
+ }
+ if (primalType->getOp() == kIROp_TupleType)
+ return builder->getTupleType(diffTypeList);
+ else
+ return builder->getTypePack((UInt)diffTypeList.getCount(), diffTypeList.getBuffer());
}
default:
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h
index f672631e3..f7f2dd6f2 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.h
+++ b/source/slang/slang-ir-autodiff-transcriber-base.h
@@ -91,10 +91,6 @@ struct AutoDiffTranscriberBase
void maybeMigrateDifferentiableDictionaryFromDerivativeFunc(IRBuilder* builder, IRInst* origFunc);
- // Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
- IRWitnessTable* getDifferentialPairWitness(IRBuilder* builder, IRInst* inOriginalDiffPairType, IRInst* inPrimalDiffPairType);
- IRWitnessTable* getArrayWitness(IRBuilder* builder, IRInst* inOriginalArrayType, IRInst* inPrimalArrayType);
-
IRInst* tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType);
IRType* getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness);
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 05884d13d..f8f6b03ab 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -1486,6 +1486,9 @@ struct DiffTransposePass
return transposeMakeStruct(builder, fwdInst, revValue);
case kIROp_MakeArray:
return transposeMakeArray(builder, fwdInst, revValue);
+ case kIROp_MakeTuple:
+ case kIROp_MakeValuePack:
+ return transposeMakeTuple(builder, fwdInst, revValue);
case kIROp_MakeArrayFromElement:
return transposeMakeArrayFromElement(builder, fwdInst, revValue);
@@ -1898,6 +1901,29 @@ struct DiffTransposePass
return TranspositionResult(gradients);
}
+ TranspositionResult transposeMakeTuple(IRBuilder* builder, IRInst* fwdMakeTuple, IRInst* revValue)
+ {
+ List<RevGradient> gradients;
+ auto type = fwdMakeTuple->getDataType();
+ for (UInt ii = 0; ii < type->getOperandCount(); ii++)
+ {
+ auto elementType = (IRType*)type->getOperand(ii);
+ auto gradAtField = builder->emitGetTupleElement(
+ elementType,
+ revValue,
+ ii);
+ SLANG_RELEASE_ASSERT(ii < fwdMakeTuple->getOperandCount());
+ gradients.add(RevGradient(
+ RevGradient::Flavor::Simple,
+ fwdMakeTuple->getOperand(ii),
+ gradAtField,
+ fwdMakeTuple));
+ }
+
+ // (A = MakeTuple(F1, F2, F3)) -> [(dF1 += dA.F1), (dF2 += dA.F2), (dF3 += dA.F3)]
+ return TranspositionResult(gradients);
+ }
+
TranspositionResult transposeMakeStruct(IRBuilder* builder, IRInst* fwdMakeStruct, IRInst* revValue)
{
List<RevGradient> gradients;
@@ -2429,25 +2455,38 @@ struct DiffTransposePass
auto baseType = firstFwdSwizzleInst->getBase()->getDataType();
IRIntegerValue elementCount = 0;
- IRType* elementType = nullptr;
- IRType* primalElementType = nullptr;
+ List<IRType*> elementTypes;
+ List<IRType*> primalElementTypes;
bool isVectorType = false;
-
+ bool isTupleType = false;
if (auto vectorType = as<IRVectorType>(baseType))
{
IRInst* elementCountInst = vectorType->getElementCount();
- elementType = vectorType->getElementType();
- primalElementType = as<IRVectorType>(aggPrimalType)->getElementType();
- SLANG_ASSERT(as<IRIntLit>(elementCountInst));
elementCount = as<IRIntLit>(elementCountInst)->getValue();
+ for (IRIntegerValue i = 0; i < elementCount; i++)
+ {
+ elementTypes.add(vectorType->getElementType());
+ primalElementTypes.add(as<IRVectorType>(aggPrimalType)->getElementType());
+ }
+ SLANG_ASSERT(as<IRIntLit>(elementCountInst));
isVectorType = true;
}
else if (auto basicType = as<IRBasicType>(baseType))
{
- elementType = basicType;
- primalElementType = aggPrimalType;
+ elementTypes.add(basicType);
+ primalElementTypes.add(aggPrimalType);
elementCount = 1;
}
+ else if (as<IRTupleType>(baseType) || as<IRTypePack>(baseType))
+ {
+ isTupleType = true;
+ elementCount = baseType->getOperandCount();
+ for (UInt i = 0; i < baseType->getOperandCount(); i++)
+ {
+ elementTypes.add((IRType*)baseType->getOperand(i));
+ primalElementTypes.add((IRType*)(aggPrimalType->getOperand(i)));
+ }
+ }
else
{
SLANG_UNREACHABLE("unknown operand type of swizzle.");
@@ -2456,18 +2495,22 @@ struct DiffTransposePass
IRInst* targetInst = firstGradient.targetInst;
// Make a list of zeros of the base type.
- auto zeroElement = emitDZeroOfDiffInstType(builder, primalElementType);
List<IRInst*> elementGrads;
+ List<IRInst*> zeroElements;
for (Index i = 0; i < elementCount; ++i)
+ {
+ auto zeroElement = emitDZeroOfDiffInstType(builder, primalElementTypes[i]);
elementGrads.add(zeroElement);
+ zeroElements.add(zeroElement);
+ }
auto accGrad = [&](UIndex i, IRInst* grad)
{
- if (elementGrads[i] == zeroElement)
+ if (elementGrads[i] == zeroElements[i])
elementGrads[i] = grad;
else
- elementGrads[i] = emitDAddOfDiffInstType(builder, primalElementType, elementGrads[i], grad);
+ elementGrads[i] = emitDAddOfDiffInstType(builder, primalElementTypes[i], elementGrads[i], grad);
};
for (auto gradient : gradients)
@@ -2493,12 +2536,19 @@ struct DiffTransposePass
else if (isVectorType)
accGrad((UIndex)targetIndex,
builder->emitElementExtract(
- elementType,
+ elementTypes[(UIndex)targetIndex],
gradient.revGradInst,
builder->getIntValue(
builder->getIntType(),
sourceIndex)));
- // Case 3: Swizzled input is a scalar.
+ // Case 3: Swizzled input is a tuple or type pack.
+ else if (isTupleType)
+ accGrad((UIndex)targetIndex,
+ builder->emitGetTupleElement(
+ elementTypes[(UIndex)targetIndex],
+ gradient.revGradInst,
+ (UInt)sourceIndex));
+ // Case 4: Swizzled input is a scalar.
else
accGrad((UIndex)targetIndex, gradient.revGradInst);
}
@@ -2509,6 +2559,17 @@ struct DiffTransposePass
targetInst,
builder->emitMakeVector(baseType, (UInt)elementCount, elementGrads.getBuffer()),
nullptr);
+ else if (isTupleType)
+ {
+ return RevGradient(
+ targetInst,
+ builder->emitIntrinsicInst(
+ baseType,
+ baseType->getOp()==kIROp_TupleType ? kIROp_MakeTuple : kIROp_MakeValuePack,
+ (UInt)elementCount,
+ elementGrads.getBuffer()),
+ nullptr);
+ }
else
return RevGradient(
targetInst,
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 6b275179c..b7c2037e5 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -44,6 +44,13 @@ IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementK
return entry->getRequirementVal();
}
}
+ else if (as<IRMakeWitnessPack>(witness))
+ {
+ // We are looking up a witness from a type pack.
+ // This is only allowed if we are looking up a differential type.
+ // We should turn this into an actual witness table for the type pack/tuple type.
+ SLANG_UNEXPECTED("looking up from a witness pack is invalid and should have been lowered.");
+ }
else
{
return builder->emitLookupInterfaceMethodInst(
@@ -434,10 +441,33 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
}
else
{
- differentiableWitnessDictionary.add((IRType*)item->getConcreteType(), item->getWitness());
+ auto witness = item->getWitness();
// Also register the type's differential type with the same witness.
+ auto concreteType = item->getConcreteType();
IRBuilder subBuilder(item->getConcreteType());
+ if (as<IRTypePack>(concreteType) || as<IRTupleType>(concreteType))
+ {
+ // For tuple types, register the differential type for each element, but don't register for the
+ // tuple/typepack itself.
+ auto witnessPack = as<IRMakeWitnessPack>(witness);
+ SLANG_ASSERT(witnessPack);
+
+ for (UInt i = 0; i < concreteType->getOperandCount(); i++)
+ {
+ auto element = concreteType->getOperand(i);
+ auto elementWitness = witnessPack->getOperand(i);
+ differentiableWitnessDictionary.addIfNotExists(
+ (IRType*)element,
+ _lookupWitness(&subBuilder, elementWitness, sharedContext->differentialAssocTypeStructKey));
+ }
+ return;
+ }
+ else
+ {
+ differentiableWitnessDictionary.add((IRType*)item->getConcreteType(), item->getWitness());
+ }
+
if (!as<IRInterfaceType>(item->getConcreteType()))
{
differentiableWitnessDictionary.addIfNotExists(
@@ -768,16 +798,18 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(IRBuilder* build
SLANG_UNIMPLEMENTED_X("Impl");
}
+ case kIROp_TypePack:
case kIROp_TupleType:
{
- auto tupleType = as<IRTupleType>(primalType);
List<IRType*> diffTypeList;
// TODO: what if we have type parameters here?
- for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++)
+ for (UIndex ii = 0; ii < primalType->getOperandCount(); ii++)
diffTypeList.add(
- differentiateType(builder, (IRType*)tupleType->getOperand(ii)));
-
- return builder->getTupleType(diffTypeList);
+ differentiateType(builder, (IRType*)primalType->getOperand(ii)));
+ if (primalType->getOp() == kIROp_TupleType)
+ return builder->getTupleType(diffTypeList);
+ else
+ return builder->getTypePack((UInt)diffTypeList.getCount(), diffTypeList.getBuffer());
}
default:
@@ -795,6 +827,12 @@ IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuil
{
SLANG_RELEASE_ASSERT(witness || as<IRArrayType>(primalType));
}
+ if (as<IRMakeWitnessPack>(witness))
+ {
+ // If registered witness is a witness pack for a type pack,
+ // we should reconstruct the true witness table.
+ witness = nullptr;
+ }
if (!witness)
{
@@ -811,6 +849,14 @@ IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuil
{
witness = getExtractExistensialTypeWitness(builder, extractExistential);
}
+ else if (auto typePack = as<IRTypePack>(primalType))
+ {
+ witness = getTupleWitness(builder, typePack);
+ }
+ else if (auto tupleType = as<IRTupleType>(primalType))
+ {
+ witness = getTupleWitness(builder, tupleType);
+ }
}
return witness;
}
@@ -963,6 +1009,104 @@ IRInst* DifferentiableTypeConformanceContext::getArrayWitness(IRBuilder* builder
return table;
}
+// Synthesizes a differentiable-interface witness table for a tuple-like type
+// (`inTupleType` is either an `IRTupleType` or an `IRTypePack`).
+//
+// The table maps:
+//   - the differential associated type  -> the element-wise differentiated tuple/pack,
+//   - the differential's own witness    -> this same table,
+//   - the add method                    -> a function adding two differentials element-wise,
+//   - the zero method                   -> a function producing the element-wise zero.
+//
+// Elements that have no differentiable witness contribute a void value to the result.
+// Returns nullptr when the type cannot be differentiated; otherwise returns the new
+// table, which is also recorded in `differentiableWitnessDictionary`.
+//
+IRInst* DifferentiableTypeConformanceContext::getTupleWitness(IRBuilder* builder, IRInst* inTupleType)
+{
+    // Differentiate the tuple type to get its differential (which is itself a tuple or pack).
+    auto diffTupleType = (IRType*)differentiateType(builder, (IRType*)inTupleType);
+
+    if (!diffTupleType)
+        return nullptr;
+
+    auto addMethod = builder->createFunc();
+    auto zeroMethod = builder->createFunc();
+
+    auto table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)inTupleType);
+
+    // Place the differential type and the (still empty) methods in the synthesized witness table.
+    builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffTupleType);
+    builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table);
+    builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
+    builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
+
+    // Fill in differential method implementations.
+    {
+        // Add method: (diffTuple, diffTuple) -> diffTuple, computed element by element.
+        IRBuilder b = *builder;
+        b.setInsertInto(addMethod);
+        b.addBackwardDifferentiableDecoration(addMethod);
+        IRType* paramTypes[2] = { diffTupleType, diffTupleType };
+        addMethod->setFullType(b.getFuncType(2, paramTypes, diffTupleType));
+        b.emitBlock();
+        auto p0 = b.emitParam(diffTupleType);
+        auto p1 = b.emitParam(diffTupleType);
+        List<IRInst*> results;
+        for (UInt i = 0; i < inTupleType->getOperandCount(); i++)
+        {
+            auto elementType = inTupleType->getOperand(i);
+            auto diffElementType = (IRType*)diffTupleType->getOperand(i);
+            auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType);
+            IRInst* elementResult = nullptr;
+            if (!innerWitness)
+            {
+                // Non-differentiable element: nothing to add.
+                elementResult = b.getVoidValue();
+            }
+            else
+            {
+                // Delegate to the element type's own add method.
+                auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey);
+                auto iVal = b.getIntValue(b.getIntType(), i);
+                IRInst* args[2] = {
+                    b.emitGetTupleElement(diffElementType, p0, iVal),
+                    b.emitGetTupleElement(diffElementType, p1, iVal) };
+                elementResult = b.emitCallInst(diffElementType, innerAdd, 2, args);
+            }
+            results.add(elementResult);
+        }
+        // Re-pack the per-element results using the same aggregate kind as the differential type.
+        IRInst* resultVal = nullptr;
+        if (diffTupleType->getOp() == kIROp_TupleType)
+            resultVal = b.emitMakeTuple(diffTupleType, results);
+        else
+            resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer());
+        b.emitReturn(resultVal);
+    }
+    {
+        // Zero method: () -> diffTuple, computed element by element.
+        //
+        // NOTE: this body must be emitted into `zeroMethod`. Emitting it into `addMethod`
+        // would append a second block to the add method, overwrite its function type with
+        // the zero-arity signature, and leave `zeroMethod` without a body or a type.
+        IRBuilder b = *builder;
+        b.setInsertInto(zeroMethod);
+        b.addBackwardDifferentiableDecoration(zeroMethod);
+        zeroMethod->setFullType(b.getFuncType(0, nullptr, diffTupleType));
+        b.emitBlock();
+        List<IRInst*> results;
+        for (UInt i = 0; i < inTupleType->getOperandCount(); i++)
+        {
+            auto elementType = inTupleType->getOperand(i);
+            auto diffElementType = (IRType*)diffTupleType->getOperand(i);
+            auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType);
+            IRInst* elementResult = nullptr;
+            if (!innerWitness)
+            {
+                // Non-differentiable element: its zero is just void.
+                elementResult = b.getVoidValue();
+            }
+            else
+            {
+                // Delegate to the element type's own zero method.
+                auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey);
+                elementResult = b.emitCallInst(diffElementType, innerZero, 0, nullptr);
+            }
+            results.add(elementResult);
+        }
+        // Re-pack the per-element results using the same aggregate kind as the differential type.
+        IRInst* resultVal = nullptr;
+        if (diffTupleType->getOp() == kIROp_TupleType)
+            resultVal = b.emitMakeTuple(diffTupleType, results);
+        else
+            resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer());
+        b.emitReturn(resultVal);
+    }
+
+    // Record this in the context for future lookups
+    differentiableWitnessDictionary[(IRType*)inTupleType] = table;
+
+    return table;
+}
+
+
IRInst* DifferentiableTypeConformanceContext::getExtractExistensialTypeWitness(
IRBuilder* builder,
IRExtractExistentialType* extractExistentialType)
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index d8f0373ac..23ae717be 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -191,6 +191,8 @@ struct DifferentiableTypeConformanceContext
IRInst* getArrayWitness(IRBuilder* builder, IRArrayType* pairType);
+ IRInst* getTupleWitness(IRBuilder* builder, IRInst* tupleType);
+
IRInst* getExtractExistensialTypeWitness(IRBuilder* builder, IRExtractExistentialType* extractExistentialType);
IRType* getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness);
@@ -240,6 +242,11 @@ struct DifferentiableTypeConformanceContext
diffElementType,
as<IRArrayType>(origType)->getElementCount());
}
+ case kIROp_TupleType:
+ case kIROp_TypePack:
+ {
+ return differentiateType(builder, origType);
+ }
case kIROp_DifferentialPairUserCodeType:
{
auto diffPairType = as<IRDifferentialPairTypeBase>(origType);
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 80c810620..179ed3065 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -243,9 +243,11 @@ INST(AssociatedType, associated_type, 0, HOISTABLE)
INST(ThisType, this_type, 0, HOISTABLE)
INST(RTTIType, rtti_type, 0, HOISTABLE)
INST(RTTIHandleType, rtti_handle_type, 0, HOISTABLE)
-INST(TupleType, tuple_type, 0, HOISTABLE)
+/*TupleTypeBase*/
+ INST(TupleType, tuple_type, 0, HOISTABLE)
+ INST(TypePack, TypePack, 0, HOISTABLE)
+INST_RANGE(TupleTypeBase, TupleType, TypePack)
INST(TargetTupleType, TargetTuple, 0, HOISTABLE)
-INST(TypePack, TypePack, 0, HOISTABLE)
INST(ExpandTypeOrVal, ExpandTypeOrVal, 1, HOISTABLE)
// A type that identifies it's contained type as being emittable as `spirv_literal.
diff --git a/source/slang/slang-ir-lower-expand-type.cpp b/source/slang/slang-ir-lower-expand-type.cpp
new file mode 100644
index 000000000..8b68b1fc1
--- /dev/null
+++ b/source/slang/slang-ir-lower-expand-type.cpp
@@ -0,0 +1,167 @@
+#include "slang-ir-lower-expand-type.h"
+
+#include "slang-ir-insts.h"
+#include "slang-ir-util.h"
+#include "slang-ir-clone.h"
+
+namespace Slang
+{
+ IRInst* clonePatternVal(IRCloneEnv& cloneEnv, IRBuilder* builder, IRInst* val, IRInst* eachIndex);
+
+ // Worker for `clonePatternVal`: rewrites `val` so that every `IREach` reachable through
+ // its operands or type is replaced by a `GetTupleElement(pack, eachIndex)`.
+ //
+ // Returns `val` itself when nothing in it needed rewriting; otherwise emits (through
+ // `builder`) and returns a new inst with the same opcode and the rewritten operands/type.
+ // Operand and type recursion goes through `clonePatternVal`, so results are memoized
+ // in `cloneEnv`.
+ //
+ IRInst* clonePatternValImpl(IRCloneEnv& cloneEnv, IRBuilder* builder, IRInst* val, IRInst* eachIndex)
+ {
+ if (!val)
+ return val;
+
+ switch (val->getOp())
+ {
+ case kIROp_ExpandTypeOrVal:
+ // Nested expand types are returned untouched here.
+ return val;
+ case kIROp_Each:
+ {
+ // `each P` becomes element `eachIndex` of the (recursively rewritten) pack `P`.
+ // Note that the pack is rewritten via the non-memoizing Impl directly.
+ auto eachInst = as<IREach>(val);
+ auto packInst = eachInst->getElement();
+ packInst = clonePatternValImpl(cloneEnv, builder, packInst, eachIndex);
+ auto result = builder->emitGetTupleElement(val->getFullType(), packInst, eachIndex);
+ return result;
+ }
+ case kIROp_Specialize:
+ case kIROp_LookupWitness:
+ case kIROp_ExtractExistentialType:
+ case kIROp_ExtractExistentialWitnessTable:
+ // These opcodes always fall through to the operand rewrite below, bypassing the
+ // "global non-type inst" early-out in the default case.
+ break;
+ default:
+ // If the value is not a type, and it is not in a block, then it is some global inst
+ // that shouldn't be deep copied into current block, such as a IRFunc.
+ if (!as<IRType>(val) && getBlock(val->getParent()) == nullptr)
+ return val;
+ break;
+ }
+ // Rewrite every operand and the type, tracking whether anything actually changed.
+ bool anyChange = false;
+ ShortList<IRInst*> operands;
+ for (UInt i = 0; i < val->getOperandCount(); i++)
+ {
+ auto newOperand = clonePatternVal(cloneEnv, builder, val->getOperand(i), eachIndex);
+ if (newOperand != val->getOperand(i))
+ anyChange = true;
+ operands.add(newOperand);
+ }
+ auto newType = clonePatternVal(cloneEnv, builder, val->getFullType(), eachIndex);
+ if (newType != val->getFullType())
+ anyChange = true;
+ // Nothing referenced an `each`: reuse the original inst instead of duplicating it.
+ if (!anyChange)
+ return val;
+
+ // Re-emit the inst with the rewritten type/operands, then carry over its decorations
+ // and children.
+ auto newVal = builder->emitIntrinsicInst((IRType*)newType, val->getOp(), operands.getCount(), operands.getArrayView().getBuffer());
+ if (newVal != val)
+ {
+ cloneInstDecorationsAndChildren(&cloneEnv, builder->getModule(), val, newVal);
+ }
+ return newVal;
+ }
+
+ // Memoizing front-end for `clonePatternValImpl`.
+ //
+ // A value that already has an entry in `cloneEnv.mapOldValToNew` is returned from the
+ // map directly. Otherwise the value is first registered as mapping to itself, so a
+ // lookup of the same value made while it is still being rewritten gets the original
+ // back; the entry is then overwritten with the rewritten result.
+ //
+ IRInst* clonePatternVal(IRCloneEnv& cloneEnv, IRBuilder* builder, IRInst* val, IRInst* eachIndex)
+ {
+     auto& remap = cloneEnv.mapOldValToNew;
+
+     if (auto existing = remap.tryGetValue(val))
+         return *existing;
+
+     // Provisional identity entry, replaced below once the real result is known.
+     remap[val] = val;
+     IRInst* const rewritten = clonePatternValImpl(cloneEnv, builder, val, eachIndex);
+     remap[val] = rewritten;
+     return rewritten;
+ }
+
+ // Translate an `IRExpandType` into an `IRExpand` where the `PatternType` is defined
+ // inside the `IRExpand` body.
+ //
+ // The new `IRExpand` is emitted immediately before `expandType`, has the same full type
+ // and the same captures, and contains a single block of the form:
+ //     %eachIndex = param : int
+ //     ... pattern type rewritten so `each X` becomes GetTupleElement(X, %eachIndex) ...
+ //     yield <rewritten pattern type>
+ // Returns the new `IRExpand`; `expandType` itself is left in place for the caller to
+ // replace and remove.
+ //
+ IRInst* lowerExpandTypeImpl(IRExpandType* expandType)
+ {
+ // Turn `IRExpandType` into an `IRExpand` instruction.
+ IRBuilder builder(expandType);
+ builder.setInsertBefore(expandType);
+ // Forward the captured packs unchanged to the new inst.
+ List<IRInst*> capturedArgs;
+ IRCloneEnv cloneEnv;
+ for (UInt i = 0; i < expandType->getCaptureCount(); i++)
+ {
+ auto capturedArg = expandType->getCaptureType(i);
+ capturedArgs.add(capturedArg);
+ }
+ auto result = builder.emitExpandInst(expandType->getFullType(), expandType->getCaptureCount(), capturedArgs.getBuffer());
+ // Build the body: one block whose only parameter is the per-element index.
+ builder.setInsertInto(result);
+ builder.emitBlock();
+ auto eachIndex = builder.emitParam(builder.getIntType());
+ // Any insts needed to express the pattern are emitted into the body block.
+ auto newPatternType = clonePatternVal(cloneEnv, &builder, expandType->getPatternType(), eachIndex);
+ builder.emitYield(newPatternType);
+ return result;
+ }
+
+ // Process the body of an `IRExpand` instruction, and replace the type of children insts if it
+ // is an `IRExpandType`.
+ //
+ // For every inst in every block of `expandVal`, the inst's type and each operand defined
+ // outside of `expandVal` are run through `clonePatternVal`, using the first parameter of
+ // the first block as the element index. Any insts this creates are inserted right before
+ // the inst being processed. A single `IRCloneEnv` is shared across the whole body, so a
+ // given outside value is rewritten at most once.
+ //
+ void processExpandVal(IRExpand* expandVal)
+ {
+ IRBuilder builder(expandVal);
+ IRCloneEnv cloneEnv;
+ // The expand body's first block parameter is the per-element index.
+ auto eachIndex = expandVal->getFirstBlock()->getFirstParam();
+ for (auto block : expandVal->getBlocks())
+ {
+ for (auto inst : block->getModifiableChildren())
+ {
+ builder.setInsertBefore(inst);
+ auto newType = clonePatternVal(cloneEnv, &builder, inst->getFullType(), eachIndex);
+ if (newType != inst->getFullType())
+ {
+ // `replaceOperand` returns the inst to keep working with, so `inst` is
+ // re-bound to its result.
+ inst = builder.replaceOperand(&inst->typeUse, newType);
+ }
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ auto oldOperand = inst->getOperand(i);
+ if (!oldOperand)
+ continue;
+ // Operands defined inside this expand are handled when they are visited
+ // themselves; only values from outside need rewriting here.
+ if (isChildInstOf(oldOperand, expandVal))
+ continue;
+ auto newOperand = clonePatternVal(cloneEnv, &builder, oldOperand, eachIndex);
+ if (newOperand != inst->getOperand(i))
+ {
+ inst = builder.replaceOperand(inst->getOperands() + i, newOperand);
+ }
+ }
+ }
+ }
+ }
+
+ // Entry point of the pass: walks every inst in `module` (global insts and, transitively,
+ // all of their children), lowering each `IRExpandType` to an `IRExpand` and rewriting
+ // the body of each existing `IRExpand` via `processExpandVal`.
+ //
+ void lowerExpandType(IRModule* module)
+ {
+ // Use a work list to process all instructions in the module, and lower any `IRExpandType` we see
+ // along the way.
+
+ List<IRInst*> workList;
+ for (auto type : module->getGlobalInsts())
+ {
+ workList.add(type);
+ }
+
+ // The list is used as a stack (pop from the back), so traversal is depth-first.
+ while (workList.getCount() != 0)
+ {
+ auto inst = workList.getLast();
+ workList.removeLast();
+
+ if (auto expandType = as<IRExpandType>(inst))
+ {
+ // Swap the expand type for its lowered form; `inst` now refers to the
+ // replacement, so the children pushed below are those of the new inst.
+ inst = lowerExpandTypeImpl(expandType);
+ if (inst != expandType)
+ {
+ expandType->replaceUsesWith(inst);
+ expandType->removeAndDeallocate();
+ }
+ }
+ else if (auto expandVal = as<IRExpand>(inst))
+ {
+ processExpandVal(expandVal);
+ }
+ // Queue the children of whatever inst we ended up with.
+ for (auto child : inst->getChildren())
+ {
+ workList.add(child);
+ }
+ }
+ }
+}
diff --git a/source/slang/slang-ir-lower-expand-type.h b/source/slang/slang-ir-lower-expand-type.h
new file mode 100644
index 000000000..28136e8c0
--- /dev/null
+++ b/source/slang/slang-ir-lower-expand-type.h
@@ -0,0 +1,30 @@
+#pragma once
+
+namespace Slang
+{
+ struct IRModule;
+
+ // After IR lowering, an `expand each X` type will be defined in the IR as:
+ // %X = ...
+ // %e = IREach(%X)
+ // %expand = IRExpandType(%e)
+ // This form allows our IR deduplication logic to deduplicate identical
+ // `expand` types into the same IR inst.
+ // However after lowering is done, we no longer need this deduplication service.
+ // But having expand types defined in this form is making it very difficult to
+ // specialize.
+ // This pass runs immediately after IR lowering process for a module (pre-linking)
+ // to turn `IRExpandType` into `IRExpand`, so that the above expand type will be
+ // represented as:
+ // %expand = IRExpand : IRTypeKind
+ // {
+ // %eachIndex = IRParam : int;
+ // %e = ...; // may use %eachIndex.
+ // yield %e;
+ // }
+ //
+ // After this translation, there should no longer be any IRExpandType/IREach instructions
+ // that are alive in the IR. All future passes will only need to deal with IRExpand.
+ //
+ void lowerExpandType(IRModule* module);
+}
diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp
index b5f5edb05..8405f9e78 100644
--- a/source/slang/slang-ir-peephole.cpp
+++ b/source/slang/slang-ir-peephole.cpp
@@ -329,6 +329,7 @@ struct PeepholeContext : InstPassBase
case kIROp_MakeTuple:
case kIROp_MakeValuePack:
case kIROp_MakeWitnessPack:
+ case kIROp_TypePack:
{
auto element = inst->getOperand(1);
if (auto intLit = as<IRIntLit>(element))
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index c9e94352e..a56dae025 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -8,6 +8,7 @@
#include "slang-ir-lower-witness-lookup.h"
#include "slang-ir-dce.h"
#include "slang-ir-sccp.h"
+#include "slang-ir-util.h"
#include "../core/slang-performance-profiler.h"
namespace Slang
@@ -85,6 +86,7 @@ struct SpecializationContext
{
case kIROp_GlobalGenericParam:
case kIROp_LookupWitness:
+ case kIROp_GetTupleElement:
return false;
case kIROp_Specialize:
// The `specialize` instruction is a bit sepcial,
@@ -589,9 +591,6 @@ struct SpecializationContext
case kIROp_Expand:
return maybeSpecializeExpand(as<IRExpand>(inst));
- case kIROp_ExpandTypeOrVal:
- return maybeSpecializeExpandTypeOrVal(as<IRExpandType>(inst));
-
case kIROp_GetTupleElement:
return maybeSpecializeFoldableInst(inst);
@@ -605,6 +604,15 @@ struct SpecializationContext
case kIROp_CountOf:
return maybeSpecializeCountOf(inst);
+
+ case kIROp_Func:
+
+ if (tryExpandParameterPack(as<IRFunc>(inst)))
+ {
+ addUsersToWorkList(inst);
+ return true;
+ }
+ return false;
}
}
@@ -1010,6 +1018,9 @@ struct SpecializationContext
workList.removeLast();
workListSet.remove(inst);
+ if (!inst->getParent() && inst->getOp() != kIROp_Module)
+ continue;
+
// For each instruction we process, we want to perform
// a few steps.
//
@@ -1182,11 +1193,8 @@ struct SpecializationContext
auto newWrapExistential = builder.emitWrapExistential(
resultType, newCall, slotOperandCount, slotOperands.getArrayView().getBuffer());
inst->replaceUsesWith(newWrapExistential);
- workList.remove(inst);
inst->removeAndDeallocate();
addUsersToWorkList(newWrapExistential);
-
- workList.remove(wrapExistential);
SLANG_ASSERT(!wrapExistential->hasUses());
wrapExistential->removeAndDeallocate();
return true;
@@ -1209,6 +1217,14 @@ struct SpecializationContext
if (maybeSpecializeBufferLoadCall(inst))
return false;
+ // If any arguments are value packs, we need to flatten them.
+ bool isCalleeFullyExpanded = false;
+ tryExpandParameterPack(as<IRFunc>(inst->getCallee()), &isCalleeFullyExpanded);
+ if (isCalleeFullyExpanded)
+ {
+ inst = tryExpandArgPack((IRCall*)inst);
+ }
+
// We can only specialize a call when the callee function is known.
//
auto calleeFunc = as<IRFunc>(inst->getCallee());
@@ -2402,13 +2418,9 @@ struct SpecializationContext
break;
}
}
- auto type = clonePatternVal(*subEnv, builder, childInst->getFullType(), index);
- for (UInt i = 0; i < childInst->getOperandCount(); i++)
- {
- clonePatternVal(*subEnv, builder, childInst->getOperand(i), index);
- }
auto newInst = cloneInst(subEnv, builder, childInst);
- newInst = builder->replaceOperand(&newInst->typeUse, type);
+ if (newInst != childInst)
+ addToWorkList(newInst);
subEnv->mapOldValToNew[childInst] = newInst;
IRBuilder subBuilder(*builder);
subBuilder.setInsertInto(newInst);
@@ -2419,6 +2431,32 @@ struct SpecializationContext
return newInst;
}
+ // Builds the pack that results from specializing an expand, choosing the pack flavor
+ // from `type` (the data type of the expand being replaced):
+ //   - witness table type  -> MakeWitnessPack, typed by a type pack of the elements' types;
+ //   - type kind/type type -> a TypePack whose operands are the elements themselves;
+ //   - anything else       -> MakeValuePack of the elements.
+ //
+ IRInst* makeSpecializedPack(IRBuilder& builder, IRType* type, ArrayView<IRInst*> elements)
+ {
+     if (as<IRWitnessTableType>(type))
+     {
+         // A witness pack is typed by the pack of its elements' data types.
+         List<IRType*> elementTypes;
+         for (auto element : elements)
+             elementTypes.add(element->getDataType());
+         auto witnessPackType = builder.getTypePack(elements.getCount(), elementTypes.getBuffer());
+         return builder.emitMakeWitnessPack(witnessPackType, elements);
+     }
+
+     if (as<IRTypeKind>(type) || as<IRTypeType>(type))
+     {
+         // The elements are types themselves, so the result is simply a type pack.
+         return builder.getTypePack(elements.getCount(), (IRType* const*)elements.getBuffer());
+     }
+
+     return builder.emitMakeValuePack((UInt)elements.getCount(), elements.getBuffer());
+ }
+
+
bool maybeSpecializeExpand(IRExpand* expandInst)
{
if (expandInst->getCaptureCount() == 0)
@@ -2440,44 +2478,57 @@ struct SpecializationContext
}
if (elementCount == 0)
{
- auto resultValuePack = builder.emitMakeValuePack(0, (IRInst*const*)nullptr);
- expandInst->replaceUsesWith(resultValuePack);
+ auto resultPack = makeSpecializedPack(builder, expandInst->getDataType(), elements.getArrayView());
+ expandInst->replaceUsesWith(resultPack);
expandInst->removeAndDeallocate();
- addUsersToWorkList(resultValuePack);
+ addUsersToWorkList(resultPack);
return true;
}
+
+ bool isMultiBlock = as<IRYield>(expandInst->getFirstBlock()->getTerminator()) == nullptr;
for (UInt i = 0; i < elementCount; i++)
{
IRCloneEnv cloneEnv;
- IRBlock* firstBlock = nullptr;
IRBuilder subBuilder = builder;
- for (auto childBlock : expandInst->getBlocks())
+ IRBlock* mergeBlock = nullptr;
+ if (isMultiBlock)
{
- auto newBlock = subBuilder.emitBlock();
- if (!firstBlock)
- firstBlock = newBlock;
- cloneEnv.mapOldValToNew[childBlock] = newBlock;
+ IRBlock* firstBlock = nullptr;
+ for (auto childBlock : expandInst->getBlocks())
+ {
+ auto newBlock = subBuilder.emitBlock();
+ if (!firstBlock)
+ firstBlock = newBlock;
+ cloneEnv.mapOldValToNew[childBlock] = newBlock;
+ }
+
+ builder.emitBranch(firstBlock);
+
+ mergeBlock = subBuilder.emitBlock();
+ builder.setInsertInto(mergeBlock);
}
+
auto indexParam = expandInst->getFirstBlock()->getFirstParam();
SLANG_ASSERT(indexParam);
cloneEnv.mapOldValToNew[indexParam] = subBuilder.getIntValue(subBuilder.getIntType(), i);
- builder.emitBranch(firstBlock);
-
- IRBlock* mergeBlock = subBuilder.emitBlock();
- builder.setInsertInto(mergeBlock);
-
for (auto childBlock : expandInst->getBlocks())
{
- auto newBlock = cloneEnv.mapOldValToNew[childBlock];
- subBuilder.setInsertInto(newBlock);
+ if (isMultiBlock)
+ {
+ auto newBlock = cloneEnv.mapOldValToNew[childBlock];
+ subBuilder.setInsertInto(newBlock);
+ }
for (auto child : childBlock->getChildren())
{
if (as<IRYield>(child))
{
- elements.add(cloneEnv.mapOldValToNew[child->getOperand(0)]);
- subBuilder.emitBranch(mergeBlock);
+ auto currentResult = child->getOperand(0);
+ currentResult = findCloneForOperand(&cloneEnv, currentResult);
+ elements.add(currentResult);
+ if (isMultiBlock)
+ subBuilder.emitBranch(mergeBlock);
continue;
}
specializeExpandChildInst(cloneEnv, &subBuilder, child, i);
@@ -2486,129 +2537,22 @@ struct SpecializationContext
}
}
- auto resultValuePack = builder.emitMakeValuePack((UInt)elements.getCount(), elements.getBuffer());
- auto currentBlock = builder.getBlock();
- for (auto nextInst = expandInst->next; nextInst;)
- {
- auto next = nextInst->next;
- nextInst->insertAtEnd(currentBlock);
- nextInst = next;
- }
- addUsersToWorkList(expandInst);
- expandInst->replaceUsesWith(resultValuePack);
- expandInst->removeAndDeallocate();
- return true;
- }
- IRInst* clonePatternValImpl(IRCloneEnv& cloneEnv, IRBuilder* builder, IRInst* val, UInt indexInPack)
- {
- if (!val)
- return val;
-
- switch (val->getOp())
- {
- case kIROp_ExpandTypeOrVal:
- return val;
- case kIROp_Each:
+ IRInst* resultPack = makeSpecializedPack(builder, expandInst->getDataType(), elements.getArrayView());
+ if (isMultiBlock)
{
- auto eachInst = as<IREach>(val);
- auto packInst = eachInst->getElement();
- if (auto typePack = as<IRTypePack>(packInst))
- {
- SLANG_RELEASE_ASSERT(indexInPack < typePack->getOperandCount());
- return typePack->getOperand(indexInPack);
- }
- else if (auto makeValuePack = as<IRMakeValuePack>(packInst))
- {
- SLANG_RELEASE_ASSERT(indexInPack < makeValuePack->getOperandCount());
- return makeValuePack->getOperand(indexInPack);
- }
- else if (!as<IRTypeKind>(packInst->getDataType()))
+ auto currentBlock = builder.getBlock();
+ for (auto nextInst = expandInst->next; nextInst;)
{
- auto type = clonePatternVal(cloneEnv, builder, val, indexInPack);
- return builder->emitGetTupleElement((IRType*)type, packInst, indexInPack);
+ auto next = nextInst->next;
+ nextInst->insertAtEnd(currentBlock);
+ nextInst = next;
}
- return val;
- }
- default:
- break;
- }
- bool anyChange = false;
- ShortList<IRInst*> operands;
- for (UInt i = 0; i < val->getOperandCount(); i++)
- {
- auto newOperand = clonePatternVal(cloneEnv, builder, val->getOperand(i), indexInPack);
- if (newOperand != val->getOperand(i))
- anyChange = true;
- operands.add(newOperand);
- }
- auto newType = clonePatternVal(cloneEnv, builder, val->getFullType(), indexInPack);
- if (newType != val->getFullType())
- anyChange = true;
- if (!anyChange)
- return val;
-
- auto newVal = builder->emitIntrinsicInst((IRType*)newType, val->getOp(), operands.getCount(), operands.getArrayView().getBuffer());
- if (newVal != val)
- {
- cloneInstDecorationsAndChildren(&cloneEnv, module, val, newVal);
- }
- return newVal;
- }
-
- IRInst* clonePatternVal(IRCloneEnv& cloneEnv, IRBuilder* builder, IRInst* val, UInt indexInPack)
- {
- if (auto clonedVal = cloneEnv.mapOldValToNew.tryGetValue(val))
- return *clonedVal;
- cloneEnv.mapOldValToNew[val] = val;
- auto result = clonePatternValImpl(cloneEnv, builder, val, indexInPack);
- cloneEnv.mapOldValToNew[val] = result;
- return result;
- }
-
- bool maybeSpecializeExpandTypeOrVal(IRExpandType* expandInst)
- {
- if (expandInst->getCaptureCount() == 0)
- return false;
-
- for (UInt i = 0; i < expandInst->getCaptureCount(); i++)
- {
- if (!as<IRTypePack>(expandInst->getCaptureType(i)))
- return false;
- }
- IRBuilder builder(expandInst);
- builder.setInsertBefore(expandInst);
- List<IRInst*> elements;
- UInt elementCount = 0;
- if (auto firstTypePack = as<IRTypePack>(expandInst->getCaptureType(0)))
- {
- elementCount = firstTypePack->getOperandCount();
- }
- for (UInt i = 0; i < elementCount; i++)
- {
- IRCloneEnv cloneEnv;
- auto element = clonePatternVal(cloneEnv, &builder, expandInst->getPatternType(), i);
- elements.add(element);
}
addUsersToWorkList(expandInst);
- if (as<IRWitnessTableType>(expandInst->getDataType()))
- {
- List<IRType*> types;
- for (auto element : elements)
- types.add(element->getDataType());
- auto newTypePack = builder.getTypePack(elements.getCount(), types.getBuffer());
- auto result = builder.emitMakeWitnessPack(newTypePack, elements.getArrayView());
- expandInst->replaceUsesWith(result);
- expandInst->removeAndDeallocate();
- return true;
- }
- else
- {
- auto newTypePack = builder.getTypePack(elements.getCount(), (IRType*const*)elements.getBuffer());
- expandInst->replaceUsesWith(newTypePack);
- expandInst->removeAndDeallocate();
- return true;
- }
+ expandInst->replaceUsesWith(resultPack);
+ expandInst->removeAndDeallocate();
+ return true;
}
// The handling of specialization for global generic type
@@ -2680,6 +2624,108 @@ struct SpecializationContext
}
}
}
+
+
+ // If `func` has any parameters whose types are `IRTypePack`, then we will expand them
+ // into multiple parameters, so that the function has no parameters of type `IRTypePack`.
+ // Returns true if changes are made.
+ // For example, this function turns `int f(TypePack<int, float> v)` into
+ // ```
+ // int f(int v0, float v1)
+ // {
+ // v = MakeValuePack(v0, v1);
+ // ...
+ // }
+ // ```
+ //
+ // `outIsFullyExpanded` (optional) is set to true unless some parameter still has an
+ // `IRExpand` type, in which case it is set to false and the function is left
+ // untouched. It is NOT written when `func` is null, so callers should initialize it.
+ //
+ bool tryExpandParameterPack(IRFunc* func, bool* outIsFullyExpanded = nullptr)
+ {
+ if (!func)
+ return false;
+ if (outIsFullyExpanded)
+ *outIsFullyExpanded = true;
+ // Collect the pack-typed parameters first; the parameter list is mutated below.
+ ShortList<IRInst*> params;
+ for (auto param : func->getParams())
+ {
+ if (as<IRTypePack>(param->getDataType()))
+ params.add(param);
+ if (as<IRExpand>(param->getDataType()))
+ {
+ // A parameter type is still an unspecialized expand: bail out without
+ // changing anything, even if other parameters are already concrete packs.
+ if (outIsFullyExpanded)
+ *outIsFullyExpanded = false;
+ return false;
+ }
+ }
+ if (params.getCount() == 0)
+ return false;
+
+ IRBuilder builder(func);
+ for (auto param : params)
+ {
+ builder.setInsertBefore(param);
+ auto typePack = as<IRTypePack>(param->getDataType());
+ // One new parameter per pack element, inserted where the pack parameter was.
+ ShortList<IRInst*> newParams;
+ for (UInt i = 0; i < typePack->getOperandCount(); i++)
+ {
+ auto newParam = builder.createParam((IRType*)typePack->getOperand(i));
+ newParam->insertBefore(param);
+ newParams.add(newParam);
+ }
+ // Rebuild the pack value from the new parameters and use it in place of the old
+ // parameter. NOTE(review): `setInsertBeforeOrdinaryInst` presumably positions the
+ // builder after the parameter list -- confirm against its definition.
+ setInsertBeforeOrdinaryInst(&builder, param);
+ auto val = builder.emitMakeValuePack(typePack, (UInt)newParams.getCount(), newParams.getArrayView().getBuffer());
+ param->replaceUsesWith(val);
+ param->removeAndDeallocate();
+ addUsersToWorkList(val);
+ }
+
+ // The parameter list changed, so the function's type must be updated to match.
+ fixUpFuncType(func);
+ return true;
+ }
+
+ // Flattens pack-typed arguments of `call` into the argument list, so that the call has
+ // no arguments whose type is an `IRTypePack`.
+ // For example, `f(MakeValuePack(a, b))` becomes `f(a, b)`.
+ //
+ // When no argument is pack-typed, `call` is returned unchanged. Otherwise a new call is
+ // emitted right before `call`, takes over its uses and decorations, and is returned;
+ // the original call is removed.
+ //
+ IRCall* tryExpandArgPack(IRCall* call)
+ {
+     const UInt argCount = call->getArgCount();
+
+     // Locate the first pack-typed argument; if there is none, there is nothing to do.
+     UInt firstPackIndex = 0;
+     while (firstPackIndex < argCount
+         && !as<IRTypePack>(call->getArg(firstPackIndex)->getDataType()))
+     {
+         firstPackIndex++;
+     }
+     if (firstPackIndex == argCount)
+         return call;
+
+     IRBuilder builder(call);
+     builder.setInsertBefore(call);
+
+     List<IRInst*> flattenedArgs;
+     for (UInt argIndex = 0; argIndex < argCount; argIndex++)
+     {
+         auto arg = call->getArg(argIndex);
+         auto packType = as<IRTypePack>(arg->getDataType());
+         if (!packType)
+         {
+             // Ordinary argument: passed through as is.
+             flattenedArgs.add(arg);
+             continue;
+         }
+
+         // Pack argument: pass each of its elements as a separate argument.
+         const UInt packSize = packType->getOperandCount();
+         for (UInt elementIndex = 0; elementIndex < packSize; elementIndex++)
+         {
+             auto element = builder.emitGetTupleElement(
+                 (IRType*)packType->getOperand(elementIndex), arg, elementIndex);
+             flattenedArgs.add(element);
+         }
+     }
+
+     auto flattenedCall = builder.emitCallInst(call->getFullType(), call->getCallee(), flattenedArgs.getArrayView());
+     call->replaceUsesWith(flattenedCall);
+     call->transferDecorationsTo(flattenedCall);
+     call->removeAndDeallocate();
+     return flattenedCall;
+ }
+
+
};
bool specializeModule(
@@ -2785,6 +2831,13 @@ IRInst* specializeGenericImpl(
IRBuilder* builder = &builderStorage;
builder->setInsertBefore(genericVal);
+ List<IRInst*> pendingWorkList;
+ SLANG_DEFER
+ (
+ for (Index ii = pendingWorkList.getCount() - 1; ii >= 0; ii--)
+ context->addToWorkList(pendingWorkList[ii]);
+ );
+
// Now we will run through the body of the generic and
// clone each of its instructions into the global scope,
// until we reach a `return` instruction.
@@ -2825,10 +2878,11 @@ IRInst* specializeGenericImpl(
{
if (auto func = as<IRFunc>(specializedVal))
{
+ context->tryExpandParameterPack(func);
simplifyFunc(context->targetProgram, func, IRSimplificationOptions::getFast(context->targetProgram));
}
}
-
+ pendingWorkList.add(specializedVal);
return specializedVal;
}
@@ -2848,7 +2902,7 @@ IRInst* specializeGenericImpl(
//
if (context)
{
- context->addToWorkList(clonedInst);
+ pendingWorkList.add(clonedInst);
}
}
}
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index e030b6d24..817c10ec2 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -1695,6 +1695,7 @@ struct GenericChildrenMigrationContextImpl
case kIROp_ClassType:
case kIROp_Func:
case kIROp_Generic:
+ case kIROp_Expand:
return false;
default:
break;
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index e0769686c..0b0a42617 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -4114,12 +4114,17 @@ namespace Slang
// `getTupleElement(makeTuple(a_0, a_1, ... a_N), i)` then we should
// just return `a_i`, provided that the index is properly in range.
//
- if( auto makeTuple = as<IRMakeTuple>(tuple) )
+ switch(tuple->getOp())
{
- if( element < makeTuple->getOperandCount() )
+ case kIROp_MakeTuple:
+ case kIROp_MakeValuePack:
+ case kIROp_MakeWitnessPack:
+ case kIROp_TypePack:
+ if( element < tuple->getOperandCount() )
{
- return makeTuple->getOperand(element);
+ return tuple->getOperand(element);
}
+ break;
}
return emitGetTupleElement(type, tuple, getIntValue(getIntType(), element));
}
@@ -8345,6 +8350,7 @@ namespace Slang
case kIROp_DifferentialPairGetDifferential:
case kIROp_MakeDifferentialPair:
case kIROp_MakeTuple:
+ case kIROp_MakeValuePack:
case kIROp_GetTupleElement:
case kIROp_StructuredBufferLoad:
case kIROp_RWStructuredBufferLoad:
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index b1c2b001e..4a3a04404 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1933,18 +1933,22 @@ struct IRAttributedType : IRType
IRInst* getAttr() { return getOperand(1); }
};
+/// Common base of the tuple-like types `IRTupleType` and `IRTypePack`.
+/// `as<IRTupleTypeBase>` matches any opcode in the `TupleTypeBase` inst range.
+struct IRTupleTypeBase : IRType
+{
+ IR_PARENT_ISA(TupleTypeBase)
+};
+
/// Represents a tuple. Tuples are created by `IRMakeTuple` and its elements
/// are accessed via `GetTupleElement(tupleValue, IRIntLit)`.
-struct IRTupleType : IRType
+struct IRTupleType : IRTupleTypeBase
{
IR_LEAF_ISA(TupleType)
};
-
/// Represents a type pack. Type packs behave like tuples, but they have a
/// "flattening" semantics, so that MakeTypePack(MakeTypePack(T1,T2), T3) is
/// MakeTypePack(T1,T2,T3).
-struct IRTypePack : IRType
+struct IRTypePack : IRTupleTypeBase
{
IR_LEAF_ISA(TypePack)
};
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 2828752a0..02c4fae68 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -11,6 +11,7 @@
#include "slang-check.h"
#include "slang-ir-bit-field-accessors.h"
#include "slang-ir-loop-inversion.h"
+#include "slang-ir-lower-expand-type.h"
#include "slang-ir.h"
#include "slang-ir-util.h"
#include "slang-ir-constexpr.h"
@@ -1987,7 +1988,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
}
else
{
- return lowerType(context, type->getTypePack());
+ return context->irBuilder->getTupleType(lowerType(context, type->getTypePack()));
}
}
@@ -11080,6 +11081,13 @@ RefPtr<IRModule> generateIRForTranslationUnit(
// Synthesize some code we want to make sure is inlined and simplified
synthesizeBitFieldAccessors(module);
+ // Lower `IRExpandType` types to use `IRExpand`, where the pattern type
+ // is nested inside the `IRExpand` as its children, instead of being same
+ // level entities as the ExpandType itself.
+ // This will unify the specialization logic for both type and value level
+ // expansion.
+ lowerExpandType(module);
+
// Generate DebugValue insts to store values into debug variables,
// if debug symbols are enabled.
if (context->includeDebugInfo)
diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp
index 7951ddc38..5a0d41f09 100644
--- a/source/slang/slang-mangle.cpp
+++ b/source/slang/slang-mangle.cpp
@@ -608,6 +608,48 @@ namespace Slang
{
emitType(context, getResultType(context->astBuilder, callableDeclRef));
}
+
+ // Include key modifiers in the mangled name so we never deduplicate
+ // things like a nonmutating method and a mutating method.
+ bool isMutating = false;
+ bool isRefThis = false;
+ bool isFwdDiff = false;
+ bool isBwdDiff = false;
+ bool isNoDiffThis = false;
+ for (auto modifier : callableDeclRef.getDecl()->modifiers)
+ {
+ if (as<MutatingAttribute>(modifier))
+ {
+ isMutating = true;
+ }
+ else if (as<RefAttribute>(modifier))
+ {
+ isRefThis = true;
+ }
+ else if (as<ForwardDifferentiableAttribute>(modifier))
+ {
+ isFwdDiff = true;
+ }
+ else if (as<BackwardDifferentiableAttribute>(modifier))
+ {
+ isBwdDiff = true;
+ }
+ else if (as<NoDiffThisAttribute>(modifier))
+ {
+ isNoDiffThis = true;
+ }
+ }
+
+ if (isMutating)
+ emitRaw(context, "m");
+ if (isRefThis)
+ emitRaw(context, "r");
+ if (isFwdDiff)
+ emitRaw(context, "f");
+ if (isBwdDiff)
+ emitRaw(context, "b");
+ if (isNoDiffThis)
+ emitRaw(context, "n");
}
}
diff --git a/tests/language-feature/generics/variadic-0.slang b/tests/language-feature/generics/variadic-0.slang
index 8ee41647f..ac9ca2c1c 100644
--- a/tests/language-feature/generics/variadic-0.slang
+++ b/tests/language-feature/generics/variadic-0.slang
@@ -1,4 +1,6 @@
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -cpu -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
@@ -19,7 +21,7 @@ S<T> makeS<T:__BuiltinArithmeticType>(T x)
}
bool cmp<T:__BuiltinArithmeticType>(T a, int b)
{
- return a > __int_cast<T>(b);
+ return a > T(b);
}
void accept<each T>(expand each T value) {}
diff --git a/tests/language-feature/generics/variadic-void.slang b/tests/language-feature/generics/variadic-void.slang
index d44acbfd4..976c104f8 100644
--- a/tests/language-feature/generics/variadic-void.slang
+++ b/tests/language-feature/generics/variadic-void.slang
@@ -1,4 +1,6 @@
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -cpu -shaderobj -output-using-type
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
diff --git a/tests/language-feature/ifunc/diff-functor.slang b/tests/language-feature/ifunc/diff-functor.slang
new file mode 100644
index 000000000..04b0be44f
--- /dev/null
+++ b/tests/language-feature/ifunc/diff-functor.slang
@@ -0,0 +1,44 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-dx12 -use-dxil -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cpu -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -shaderobj -output-using-type
+
+struct DiffFunctor : IDifferentiableFunc<float, float>
+{
+ [Differentiable]
+ float __call(float p)
+ {
+ return p + 1;
+ }
+}
+
+float apply(IMutatingFunc<float, float> f, float p)
+{
+ return f.__call(p);
+}
+
+[Differentiable]
+float applyDiff(IDifferentiableFunc<float, float> f, float p)
+{
+ return f.__call(p);
+}
+
+[Differentiable]
+TR applyDiffGen<TR : IDifferentiable, each TP : IDifferentiable>(IDifferentiableFunc<TR, TP> f, expand each TP p)
+{
+ return f.__call(expand each p);
+}
+
+//TEST_INPUT:ubuffer(data=[0 3 2 2], stride=4):out,name=outputBuffer
+RWStructuredBuffer<uint> outputBuffer;
+
+[numthreads(1, 1, 1)]
+void computeMain(uint tid: SV_DispatchThreadID)
+{
+ // CHECK: 4
+ outputBuffer[0] = (uint)apply(DiffFunctor(), 3.0);
+ // CHECK: 1
+ outputBuffer[1] = (uint)fwd_diff(applyDiff)(DiffFunctor(), diffPair(2.0, 1.0)).d;
+ // CHECK: 1
+ outputBuffer[2] = (uint)fwd_diff(applyDiffGen<float, float>)(DiffFunctor(), diffPair(2.0, 1.0)).d;
+}
diff --git a/tests/language-feature/ifunc/ifunc.slang b/tests/language-feature/ifunc/ifunc.slang
new file mode 100644
index 000000000..f270299b3
--- /dev/null
+++ b/tests/language-feature/ifunc/ifunc.slang
@@ -0,0 +1,40 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-dx12 -use-dxil -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cpu -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -shaderobj -output-using-type
+
+struct Functor : IFunc<int, int, bool>
+{
+ int __call(int p, bool t)
+ {
+ return p + 1;
+ }
+}
+
+struct MutatingFunctor : IMutatingFunc<int, int, bool>
+{
+ int data = 0;
+ [mutating]
+ int __call(int p, bool t)
+ {
+ data++;
+ return p + 1;
+ }
+}
+
+int apply(IMutatingFunc<int, int, bool> f, int p)
+{
+ return f.__call(p, true);
+}
+
+//TEST_INPUT:ubuffer(data=[0 3 2 2], stride=4):out,name=outputBuffer
+RWStructuredBuffer<uint> outputBuffer;
+
+[numthreads(1, 1, 1)]
+void computeMain(uint tid: SV_DispatchThreadID)
+{
+ // CHECK: 2
+ outputBuffer[0] = apply(MutatingFunctor(), 1);
+ // CHECK: 3
+ outputBuffer[1] = apply(Functor(), 2);
+}