summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorJay Kwak <82421531+jkwak-work@users.noreply.github.com>2025-05-29 16:36:49 -0700
committerGitHub <noreply@github.com>2025-05-29 16:36:49 -0700
commit984d7f22f8a0909dc870c65bb927094c54f55402 (patch)
treeab255bf44e14f6cbaa09522f90b12464f1c6a339 /source/slang
parentf4d7954e088966c2ae8618b1cc17aac4d64ef013 (diff)
Implement MapElement for CoopMat (#7159)
With this PR, MapElement works for the following signatures: - CoopMat<...>::MapElement(functype(...)); - CoopMat<...>::MapElement(capturing-lambda); - CoopMat<...>::MapElement(not-capturing-lambda); - Tuple<CoopMat<...>,...>::MapElement(functype(...)); - Tuple<CoopMat<...>,...>::MapElement(capturing-lambda); - Tuple<CoopMat<...>,...>::MapElement(not-capturing-lambda);
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/hlsl.meta.slang45
-rw-r--r--source/slang/slang-ast-type.cpp27
-rw-r--r--source/slang/slang-check-constraint.cpp243
-rw-r--r--source/slang/slang-emit-spirv-ops.h26
-rw-r--r--source/slang/slang-emit-spirv.cpp50
-rw-r--r--source/slang/slang-ir-inst-defs.h2
-rw-r--r--source/slang/slang-ir-insts.h14
-rw-r--r--source/slang/slang-ir-legalize-types.cpp33
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp79
-rw-r--r--source/slang/slang-ir.cpp6
-rw-r--r--source/slang/slang-lower-to-ir.cpp10
11 files changed, 531 insertions, 4 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 801c61481..677b5d7bf 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -23415,6 +23415,23 @@ struct CoopMat
};
}
+ __intrinsic_op($(kIROp_CoopMatMapElementIFunc))
+ internal static This __MapElement<
+ TOperator,
+ TFunc : IFunc<T, uint32_t, uint32_t, T>
+ >(
+ This coopMat,
+ TOperator mapOp,
+ TFunc mapObj
+ );
+
+ This MapElement<
+ TFunc : IFunc<T, uint32_t, uint32_t, T>
+ >(TFunc mapOp)
+ {
+ return __MapElement(this, mapOp.operator(), mapOp);
+ }
+
//
// Store
//
@@ -24276,8 +24293,36 @@ CoopMat<T, S, M, N, CoopMatMatrixUse.MatrixAccumulator> coopMatMulAdd<
};
}
+extension<
+ T : __BuiltinArithmeticType,
+ let S : MemoryScope,
+ let M : int,
+ let N : int,
+ let R : linalg.CoopMatMatrixUse,
+ each Ts : __BuiltinArithmeticType
+> Tuple<linalg.CoopMat<T, S, M, N, R>, expand linalg.CoopMat<each Ts, S, M, N, R>>
+{
+ __intrinsic_op($(kIROp_CoopMatMapElementIFunc))
+ CoopMat<T, S, M, N, R> MapElement(functype(uint32_t, uint32_t, T, expand each Ts)->T mapOp);
+
+ __intrinsic_op($(kIROp_CoopMatMapElementIFunc))
+ static CoopMat<T, S, M, N, R> __MapElement<
+ TOperator,
+ TFunc : IFunc<T, uint32_t, uint32_t, T, expand each Ts>
+ >(This tuple, TOperator mapOp, TFunc mapObj);
+
+ [ForceInline]
+ CoopMat<T, S, M, N, R> MapElement<
+ TFunc : IFunc<T, uint32_t, uint32_t, T, expand each Ts>
+ >(TFunc mapOp)
+ {
+ return __MapElement(this, mapOp.operator(), mapOp);
+ }
+};
+
} // namespace linalg
+
//
// Cooperative Vector
//
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp
index 18f3f90bc..7cfc6a67a 100644
--- a/source/slang/slang-ast-type.cpp
+++ b/source/slang/slang-ast-type.cpp
@@ -578,7 +578,19 @@ Val* FuncType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet s
List<Type*> substParamTypes;
for (Index pp = 0; pp < getParamCount(); pp++)
{
- substParamTypes.add(as<Type>(getParamType(pp)->substituteImpl(astBuilder, subst, &diff)));
+ auto substParamType = as<Type>(getParamType(pp)->substituteImpl(astBuilder, subst, &diff));
+ if (auto typePack = as<ConcreteTypePack>(substParamType))
+ {
+ // Unwrap the ConcreteTypePack and add each element as a parameter
+ for (Index i = 0; i < typePack->getTypeCount(); ++i)
+ {
+ substParamTypes.add(typePack->getElementType(i));
+ }
+ }
+ else
+ {
+ substParamTypes.add(substParamType);
+ }
}
// early exit for no change...
@@ -774,7 +786,18 @@ Val* ConcreteTypePack::_substituteImplOverride(
for (Index i = 0; i < getTypeCount(); i++)
{
auto substType = as<Type>(getElementType(i)->substituteImpl(astBuilder, subst, &diff));
- substElementTypes.add(substType);
+ if (auto typePack = as<ConcreteTypePack>(substType))
+ {
+ // Unwrap the ConcreteTypePack and add each element as a parameter
+ for (Index j = 0; j < typePack->getTypeCount(); ++j)
+ {
+ substElementTypes.add(typePack->getElementType(j));
+ }
+ }
+ else
+ {
+ substElementTypes.add(substType);
+ }
}
if (!diff)
return this;
diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp
index 6f9191135..6259e2fb8 100644
--- a/source/slang/slang-check-constraint.cpp
+++ b/source/slang/slang-check-constraint.cpp
@@ -1210,6 +1210,191 @@ void SemanticsVisitor::maybeUnifyUnconstraintIntParam(
constraints.constraints.add(c);
}
+struct IndexSpan
+{
+ Index index;
+ Index count;
+
+ IndexSpan()
+ : index(0), count(0)
+ {
+ }
+ IndexSpan(Index idx, Index cnt)
+ : index(idx), count(cnt)
+ {
+ }
+};
+
+struct IndexSpanPair
+{
+ IndexSpan first;
+ IndexSpan second;
+
+ IndexSpanPair() {}
+ IndexSpanPair(IndexSpan f, IndexSpan s)
+ : first(f), second(s)
+ {
+ }
+};
+
+// Helper function to unwrap a type and count expandable types
+static void unwrapTypeAndCountExpandable(
+ Type* type,
+ ShortList<Type*>& outTypes,
+ int& outExpandableCount)
+{
+ if (auto concretePack = as<ConcreteTypePack>(type))
+ {
+ for (Index i = 0; i < concretePack->getTypeCount(); ++i)
+ {
+ outTypes.add(concretePack->getElementType(i));
+ if (isAbstractTypePack(concretePack->getElementType(i)))
+ outExpandableCount++;
+ }
+ }
+ else if (isAbstractTypePack(type))
+ {
+ outTypes.add(type);
+ outExpandableCount++;
+ }
+}
+
+// Helper function to map type arguments between two types, handling expandable types
+static bool matchTypeArgMapping(
+ Type* firstType,
+ Type* secondType,
+ ShortList<Type*>& outFlattenedFirst,
+ ShortList<Type*>& outFlattenedSecond,
+ ShortList<IndexSpanPair>& outMapping)
+{
+ // Unwrap and flatten the types
+ ShortList<Type*>& firstTypes = outFlattenedFirst;
+ ShortList<Type*>& secondTypes = outFlattenedSecond;
+
+ // Count expandable types as we unwrap
+ int firstExpandableCount = 0;
+ int secondExpandableCount = 0;
+
+ // Unwrap both types using the helper function
+ unwrapTypeAndCountExpandable(firstType, firstTypes, firstExpandableCount);
+ unwrapTypeAndCountExpandable(secondType, secondTypes, secondExpandableCount);
+
+ // We need to figure out which side should be expanding.
+ // Consider the following cases,
+ //
+ // left = [ expand, expand ]
+ // right = [ int, float, expand ]
+ // when one side has more non-expandable types, the other side should expand to match it.
+ // in this case, "left" should expand to cover "int" and "float".
+ //
+ // left = [ int, float, expand, expand ]
+ // right = [ int, float, expand ]
+ // when the number of the non-expandable types are same, we want to expand side that has
+ // fewer expandable types. In this case, "right" should expand to cover the first "expand".
+ //
+ // left = ConcreteTypePack(ExpandType, ExpandType)
+ // right = ConcreteTypePack(int, bool, float, double).
+ // In this case, we shouldn't be mapping the first ExpandType to int and the second
+ // ExpandType to bool, float, double. Instead, they should evenly divide the second type
+ // pack, so we map first ExpandType with int, bool, and second ExpandType to float, double.
+ //
+ int firstCount = (int)firstTypes.getCount();
+ int secondCount = (int)secondTypes.getCount();
+ int countDifference =
+ (firstCount - firstExpandableCount) - (secondCount - secondExpandableCount);
+
+ bool shouldExpandFirst =
+ (firstExpandableCount > 0) &&
+ ((countDifference < 0) ||
+ (countDifference == 0 && firstExpandableCount < secondExpandableCount));
+
+ bool shouldExpandSecond =
+ (secondExpandableCount > 0) &&
+ ((countDifference > 0) ||
+ (countDifference == 0 && firstExpandableCount > secondExpandableCount));
+
+ // We need to figure out how much types should match per each expandable type.
+ int typesPerExpand = 0;
+ if (shouldExpandSecond)
+ {
+ // More types on first, need to expand second
+ int countToMatch = countDifference + firstExpandableCount;
+ SLANG_ASSERT(secondExpandableCount != 0);
+ if (countToMatch % secondExpandableCount != 0)
+ return false;
+ typesPerExpand = countToMatch / secondExpandableCount;
+ }
+ else if (shouldExpandFirst)
+ {
+ // More types on second, need to expand first
+ int countToMatch = -countDifference + secondExpandableCount;
+ SLANG_ASSERT(firstExpandableCount != 0);
+ if (countToMatch % firstExpandableCount != 0)
+ return false;
+ typesPerExpand = countToMatch / firstExpandableCount;
+ }
+ // If countDifference == 0, no expansion needed
+
+ // Generate the mapping
+ Index firstIndex = 0;
+ Index secondIndex = 0;
+
+ while (firstIndex < firstCount && secondIndex < secondCount)
+ {
+ IndexSpanPair mapping;
+
+ // Determine spans based on expandable types and count difference
+ if (shouldExpandFirst)
+ {
+ // Expanding first to match second
+ if (isAbstractTypePack(firstTypes[firstIndex]))
+ {
+ mapping.first = IndexSpan(firstIndex, 1);
+ mapping.second = IndexSpan(secondIndex, typesPerExpand);
+ secondIndex += typesPerExpand;
+ }
+ else
+ {
+ mapping.first = IndexSpan(firstIndex, 1);
+ mapping.second = IndexSpan(secondIndex, 1);
+ secondIndex++;
+ }
+ firstIndex++;
+ }
+ else if (shouldExpandSecond)
+ {
+ // Expanding second to match first
+ if (isAbstractTypePack(secondTypes[secondIndex]))
+ {
+ mapping.first = IndexSpan(firstIndex, typesPerExpand);
+ mapping.second = IndexSpan(secondIndex, 1);
+ firstIndex += typesPerExpand;
+ }
+ else
+ {
+ mapping.first = IndexSpan(firstIndex, 1);
+ mapping.second = IndexSpan(secondIndex, 1);
+ firstIndex++;
+ }
+ secondIndex++;
+ }
+ else
+ {
+ // No expansion needed
+ mapping.first = IndexSpan(firstIndex, 1);
+ mapping.second = IndexSpan(secondIndex, 1);
+ firstIndex++;
+ secondIndex++;
+ }
+
+ outMapping.add(mapping);
+ }
+
+ SLANG_ASSERT(!shouldExpandSecond || firstIndex == firstCount);
+ SLANG_ASSERT(!shouldExpandFirst || secondIndex == secondCount);
+ return true;
+}
+
bool SemanticsVisitor::TryUnifyTypes(
ConstraintSystem& constraints,
ValUnificationContext unifyCtx,
@@ -1244,7 +1429,63 @@ bool SemanticsVisitor::TryUnifyTypes(
return TryUnifyConjunctionType(constraints, unifyCtx, fst, snd);
}
- // If one of the types is a type pack, we need to recursively unify the element types.
+ // Unwrap ConcreteTypePack and call TryUnifyTypes recursively.
+ ShortList<IndexSpanPair> typeMapping;
+ ShortList<Type*> flattenedFirst;
+ ShortList<Type*> flattenedSecond;
+ if (matchTypeArgMapping(fst, snd, flattenedFirst, flattenedSecond, typeMapping) &&
+ typeMapping.getCount() > 1)
+ {
+ // Apply unification based on the mapping
+ for (const auto& mapping : typeMapping)
+ {
+ // Make sure it is one of three cases: 1:1, 1:N or N:1
+ SLANG_ASSERT(mapping.first.count > 0 && mapping.second.count > 0);
+ SLANG_ASSERT(mapping.first.count == 1 || mapping.second.count == 1);
+
+ // Helper function to create QualType from mapping span
+ auto mayPackAndGetQualType = [this](
+ const IndexSpan& span,
+ ShortList<Type*>& typeList,
+ bool isLeftValue,
+ Type* otherType) -> QualType
+ {
+ // When the `otherType` is GenericTypePackParamDecl or ExpandType,
+ // we need to create a ConcreteTypePack so that we can handle them
+ // recursively.
+ if (isDeclRefTypeOf<GenericTypePackParamDecl>(otherType) ||
+ as<ExpandType>(otherType))
+ {
+ // Multiple types: create ConcreteTypePack
+ auto typesView = makeArrayView(&typeList[span.index], span.count);
+ auto typePack = m_astBuilder->getTypePack(typesView);
+ return QualType(typePack, isLeftValue);
+ }
+
+ SLANG_ASSERT(span.count == 1);
+ return QualType(typeList[span.index], isLeftValue);
+ };
+
+ // Get the types directly from the mapping
+ QualType firstArg = mayPackAndGetQualType(
+ mapping.first,
+ flattenedFirst,
+ fst.isLeftValue,
+ flattenedSecond[mapping.second.index]);
+ QualType secondArg = mayPackAndGetQualType(
+ mapping.second,
+ flattenedSecond,
+ snd.isLeftValue,
+ flattenedFirst[mapping.first.index]);
+
+ // Perform the unification
+ if (!TryUnifyTypes(constraints, unifyCtx, firstArg, secondArg))
+ return false;
+ }
+
+ return true;
+ }
+
if (auto fstTypePack = as<ConcreteTypePack>(fst))
{
if (auto sndTypePack = as<ConcreteTypePack>(snd))
diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h
index 81ad4bddf..4d33e045e 100644
--- a/source/slang/slang-emit-spirv-ops.h
+++ b/source/slang/slang-emit-spirv-ops.h
@@ -2666,4 +2666,30 @@ SpvInst* emitOpTypeNodePayloadArray(IRInst* inst, const T& type)
type);
}
+// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCooperativeMatrixPerElementOpNV
+template<typename T1, typename T2, typename T3, typename... TOperands>
+SpvInst* emiOpCooperativeMatrixPerElementOp(
+ SpvInstParent* parent,
+ IRInst* inst,
+ const T1& idResultType,
+ const T2& matrix,
+ const T3& func,
+ const TOperands&... operands)
+{
+ static_assert(isSingular<T1>);
+ static_assert(isSingular<T2>);
+ static_assert(isSingular<T3>);
+ // Emit the instruction with a variable number of operands
+ return emitInst(
+ parent,
+ inst,
+ SpvOpCooperativeMatrixPerElementOpNV,
+ idResultType,
+ kResultID,
+ matrix,
+ func,
+ operands...);
+}
+
+
#endif // SLANG_IN_SPIRV_EMIT_CONTEXT
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index ba238985b..57ad1a988 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -4276,6 +4276,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
case kIROp_MakeArray:
result = emitConstruct(parent, inst);
break;
+ case kIROp_CoopMatMapElementIFunc:
+ result = emitCoopMatMapElementWithIFunc(parent, as<IRCoopMatMapElementIFunc>(inst));
+ break;
case kIROp_MakeTensorAddressingTensorLayout:
result = emitOpCreateTensorLayout(parent, inst, getID(ensureInst(inst->getDataType())));
break;
@@ -7698,6 +7701,53 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
}
}
+ SpvInst* emitCoopMatMapElementWithIFunc(SpvInstParent* parent, IRCoopMatMapElementIFunc* inst)
+ {
+ ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_cooperative_matrix2"));
+ requireSPIRVCapability(SpvCapabilityCooperativeMatrixPerElementOperationsNV);
+
+ IRInst* matOrTuple = inst->getCoopMat();
+
+ IRInst* mat0 = nullptr;
+
+ UInt tupleCount = 0;
+ IRInst* tuple = as<IRMakeStruct>(matOrTuple);
+ if (tuple)
+ {
+ mat0 = tuple->getOperand(0);
+ tupleCount = tuple->getOperandCount();
+ }
+ else
+ {
+ mat0 = matOrTuple;
+ }
+
+ auto funcCall = inst->getIFuncCall();
+
+ IRInst* ifuncThis = nullptr;
+ if (inst->getOperandCount() > 2)
+ ifuncThis = inst->getIFuncThis();
+
+ return emitInstCustomOperandFunc(
+ parent,
+ inst,
+ SpvOpCooperativeMatrixPerElementOpNV,
+ [&]()
+ {
+ emitOperand(mat0->getDataType());
+ emitOperand(kResultID);
+
+ emitOperand(mat0);
+ emitOperand(funcCall);
+
+ if (ifuncThis)
+ emitOperand(ifuncThis);
+
+ for (UInt i = 1; i < tupleCount; i++)
+ emitOperand(tuple->getOperand(i));
+ });
+ }
+
SpvInst* emitSplat(SpvInstParent* parent, IRInst* inst, IRInst* scalar, IRIntegerValue numElems)
{
const auto scalarTy = as<IRBasicType>(scalar->getDataType());
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 419c8d59d..3b45d46b3 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -789,6 +789,8 @@ INST(AllocateTorchTensor, allocTorchTensor, 0, 0)
INST(TorchGetCudaStream, TorchGetCudaStream, 0, 0)
INST(TorchTensorGetView, TorchTensorGetView, 0, 0)
+INST(CoopMatMapElementIFunc, CoopMatMapElementIFunc, 2, 0)
+
INST(AllocateOpaqueHandle, allocateOpaqueHandle, 0, 0)
// Return the register index thtat a resource is bound to.
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 2ab4db980..2fff4e451 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -3124,6 +3124,18 @@ struct IRMakeCoopVector : IRInst
IR_LEAF_ISA(MakeCoopVector)
};
+struct IRCoopMatMapElementIFunc : IRInst
+{
+ IR_LEAF_ISA(CoopMatMapElementIFunc)
+ IRInst* getCoopMat() { return getOperand(0); }
+ IRInst* getTuple() { return getOperand(0); }
+ IRFunc* getIFuncCall() { return as<IRFunc>(getOperand(1)); }
+ IRInst* getIFuncThis() { return getOperand(2); }
+
+ bool hasIFuncThis() { return getOperandCount() > 2; }
+ void setIFuncCall(IRFunc* func) { setOperand(1, func); }
+};
+
// An Instruction that creates a differential pair value from a
// primal and differential.
@@ -4241,6 +4253,8 @@ public:
IRInst* emitGetTupleElement(IRType* type, IRInst* tuple, UInt element);
IRInst* emitGetTupleElement(IRType* type, IRInst* tuple, IRInst* element);
+ IRInst* emitCoopMatMapElementFunc(IRType* type, IRInst* tuple, IRInst* func);
+
IRInst* emitGetElement(IRType* type, IRInst* arrayLikeType, IRIntegerValue element);
IRInst* emitGetElementPtr(IRType* type, IRInst* arrayLikeType, IRIntegerValue element);
diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp
index af29d1998..529770acc 100644
--- a/source/slang/slang-ir-legalize-types.cpp
+++ b/source/slang/slang-ir-legalize-types.cpp
@@ -2282,6 +2282,36 @@ static LegalVal legalizeGlobalParam(
IRTypeLegalizationContext* context,
IRGlobalParam* irGlobalParam);
+static LegalVal legalizeCoopMatMapElementIFunc(
+ IRTypeLegalizationContext* context,
+ IRCoopMatMapElementIFunc* inst)
+{
+ // When the functor object is a lambda with no captures,
+ // it will be removed as a part of the legalization process,
+ // because it is a zero-sized struct.
+ // We need to explicitly remove it from its user.
+ if (inst->hasIFuncThis())
+ {
+ // Check if `this` is valid.
+ auto legalArg = legalizeOperand(context, inst->getIFuncThis());
+ if (legalArg.flavor == LegalVal::Flavor::none)
+ {
+ // If `this` is not valid, remove it from IR.
+ IRBuilder builder{inst};
+ builder.setInsertBefore(inst);
+ auto newInst = builder.emitCoopMatMapElementFunc(
+ inst->getFullType(),
+ inst->getOperand(0),
+ inst->getOperand(1));
+
+ inst->replaceUsesWith(newInst);
+ inst->removeAndDeallocate();
+ return LegalVal::simple(newInst);
+ }
+ }
+ return LegalVal::simple(inst);
+}
+
static LegalVal legalizeInst(IRTypeLegalizationContext* context, IRInst* inst)
{
// Any additional instructions we need to emit
@@ -2293,6 +2323,9 @@ static LegalVal legalizeInst(IRTypeLegalizationContext* context, IRInst* inst)
// Special-case certain operations
switch (inst->getOp())
{
+ case kIROp_CoopMatMapElementIFunc:
+ return legalizeCoopMatMapElementIFunc(context, cast<IRCoopMatMapElementIFunc>(inst));
+
case kIROp_Var:
return legalizeLocalVar(context, cast<IRVar>(inst));
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index 20b7f795f..31742cbde 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -1506,6 +1506,82 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
}
}
+ // Creates a new function with a different parameter order.
+ // OpCooperativeMatrixPerElementOpNV expects a callback function to have the parameter
+ // in the following order:
+ // return-type CallbackFunction(uint, uint, coopmat1, functorThis, coopmat2, coopmat3, ...)
+ // But what we have by default is:
+ // return-type CallbackFunction(functorThis, uint, uint, coopmat1, coopmat2, coopmat3, ...)
+ //
+ // The new function will do the following,
+ // return-type newCallback(uint a, uint b, T mat1, TFunc f, T mat2, T mat3, ...)
+ // { return targetFunc(f, a, b, mat1, mat2, mat3, ...); }
+ IRFunc* createWrapperFunctionForPerElement(IRBuilder& builder, IRFunc* targetFunc)
+ {
+ List<IRType*> paramTypes;
+ for (UInt i = 0; i < targetFunc->getParamCount(); i++)
+ {
+ paramTypes.add(targetFunc->getParamType(i));
+ }
+
+ SLANG_ASSERT(paramTypes.getCount() >= 4);
+
+ IRType* tempTypes[4];
+ tempTypes[3] = builder.getPtrType(paramTypes[0]);
+ tempTypes[0] = paramTypes[1];
+ tempTypes[1] = paramTypes[2];
+ tempTypes[2] = paramTypes[3];
+ paramTypes[0] = tempTypes[0];
+ paramTypes[1] = tempTypes[1];
+ paramTypes[2] = tempTypes[2];
+ paramTypes[3] = tempTypes[3];
+
+ IRType* returnType = targetFunc->getDataType()->getResultType();
+
+ IRBuilderInsertLocScope insertLocScope(&builder);
+ auto wrapperFunc = builder.createFunc();
+ builder.setDataType(wrapperFunc, builder.getFuncType(paramTypes, returnType));
+ builder.setInsertInto(wrapperFunc);
+ auto block = builder.emitBlock();
+ builder.setInsertInto(block);
+
+ List<IRInst*> params;
+ for (Index i = 0; i < paramTypes.getCount(); i++)
+ {
+ params.add(builder.emitParam(paramTypes[i]));
+ }
+
+ IRInst* tempParams[4];
+ tempParams[0] = builder.emitLoad(params[3]);
+ tempParams[1] = params[0];
+ tempParams[2] = params[1];
+ tempParams[3] = params[2];
+ params[0] = tempParams[0];
+ params[1] = tempParams[1];
+ params[2] = tempParams[2];
+ params[3] = tempParams[3];
+
+ auto result = builder.emitCallInst(returnType, targetFunc, params);
+ builder.emitReturn(result);
+ return wrapperFunc;
+ }
+
+ void processCoopMatMapElementIFunc(IRCoopMatMapElementIFunc* inst)
+ {
+ IRBuilder builder{inst};
+ builder.setInsertBefore(inst);
+
+ auto ifuncCall = inst->getIFuncCall();
+
+ // `this` of the functor is optional.
+ // Skip the synthesis if `this` is not passed.
+ if (ifuncCall->getParamCount() > 3)
+ {
+ auto funcSynth = createWrapperFunctionForPerElement(builder, ifuncCall);
+ inst->setIFuncCall(funcSynth);
+ }
+ }
+
void legalizeSPIRVEntryPoint(IRFunc* func, IREntryPointDecoration* entryPointDecor)
{
auto stage = entryPointDecor->getProfile().getStage();
@@ -1743,6 +1819,9 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
case kIROp_BitfieldInsert:
processBitFieldOp(inst);
break;
+ case kIROp_CoopMatMapElementIFunc:
+ processCoopMatMapElementIFunc(as<IRCoopMatMapElementIFunc>(inst));
+ break;
case kIROp_DebugValue:
if (!isSimpleDataType(as<IRDebugValue>(inst)->getDebugVar()->getDataType()))
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 6d5bd9a25..85fe2fa04 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -4467,6 +4467,12 @@ IRInst* IRBuilder::emitGetTupleElement(IRType* type, IRInst* tuple, UInt element
return emitGetTupleElement(type, tuple, getIntValue(getIntType(), element));
}
+IRInst* IRBuilder::emitCoopMatMapElementFunc(IRType* type, IRInst* tuple, IRInst* func)
+{
+ IRInst* args[] = {tuple, func};
+ return emitIntrinsicInst(type, kIROp_CoopMatMapElementIFunc, 2, args);
+}
+
IRInst* IRBuilder::emitMakeResultError(IRType* resultType, IRInst* errorVal)
{
return emitIntrinsicInst(resultType, kIROp_MakeResultError, 1, &errorVal);
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 758b21e93..bc00eb531 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1191,6 +1191,13 @@ top:
lowered = extractField(context, boundMemberInfo->type, base, fieldDeclRef);
goto top;
}
+ else if (auto methodDeclRef = declRef.as<CallableDecl>())
+ {
+ auto funcVal = emitDeclRef(context, declRef, boundMemberInfo->type);
+ SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple);
+ lowered = funcVal;
+ goto top;
+ }
else
{
@@ -4584,7 +4591,8 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo>
else if (auto callableDeclRef = declRef.as<CallableDecl>())
{
RefPtr<BoundMemberInfo> boundMemberInfo = new BoundMemberInfo();
- boundMemberInfo->type = nullptr;
+ boundMemberInfo->type =
+ lowerType(context, getResultType(context->astBuilder, callableDeclRef));
boundMemberInfo->base = loweredBase;
boundMemberInfo->declRef = callableDeclRef;