diff options
| author | Jay Kwak <82421531+jkwak-work@users.noreply.github.com> | 2025-05-29 16:36:49 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-05-29 16:36:49 -0700 |
| commit | 984d7f22f8a0909dc870c65bb927094c54f55402 (patch) | |
| tree | ab255bf44e14f6cbaa09522f90b12464f1c6a339 | |
| parent | f4d7954e088966c2ae8618b1cc17aac4d64ef013 (diff) | |
Implement MapElement for CoopMat (#7159)
With this PR, MapElement works for the following signatures:
- CoopMat<...>::MapElement(functype(...));
- CoopMat<...>::MapElement(capturing-lambda);
- CoopMat<...>::MapElement(not-capturing-lambda);
- Tuple<CoopMat<...>,...>::MapElement(functype(...));
- Tuple<CoopMat<...>,...>::MapElement(capturing-lambda);
- Tuple<CoopMat<...>,...>::MapElement(not-capturing-lambda);
| -rw-r--r-- | source/slang/hlsl.meta.slang | 45 | ||||
| -rw-r--r-- | source/slang/slang-ast-type.cpp | 27 | ||||
| -rw-r--r-- | source/slang/slang-check-constraint.cpp | 243 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv-ops.h | 26 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 50 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 14 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-types.cpp | 33 | ||||
| -rw-r--r-- | source/slang/slang-ir-spirv-legalize.cpp | 79 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 10 | ||||
| -rw-r--r-- | tests/cooperative-matrix/map-element-single.slang | 40 | ||||
| -rw-r--r-- | tests/cooperative-matrix/map-element-tuple.slang | 68 | ||||
| -rw-r--r-- | tests/language-feature/tuple/tuple-expand-multiple.slang | 37 | ||||
| -rw-r--r-- | tests/language-feature/tuple/tuple-expand.slang | 25 |
15 files changed, 688 insertions, 17 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 801c61481..677b5d7bf 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -23415,6 +23415,23 @@ struct CoopMat }; } + __intrinsic_op($(kIROp_CoopMatMapElementIFunc)) + internal static This __MapElement< + TOperator, + TFunc : IFunc<T, uint32_t, uint32_t, T> + >( + This coopMat, + TOperator mapOp, + TFunc mapObj + ); + + This MapElement< + TFunc : IFunc<T, uint32_t, uint32_t, T> + >(TFunc mapOp) + { + return __MapElement(this, mapOp.operator(), mapOp); + } + // // Store // @@ -24276,8 +24293,36 @@ CoopMat<T, S, M, N, CoopMatMatrixUse.MatrixAccumulator> coopMatMulAdd< }; } +extension< + T : __BuiltinArithmeticType, + let S : MemoryScope, + let M : int, + let N : int, + let R : linalg.CoopMatMatrixUse, + each Ts : __BuiltinArithmeticType +> Tuple<linalg.CoopMat<T, S, M, N, R>, expand linalg.CoopMat<each Ts, S, M, N, R>> +{ + __intrinsic_op($(kIROp_CoopMatMapElementIFunc)) + CoopMat<T, S, M, N, R> MapElement(functype(uint32_t, uint32_t, T, expand each Ts)->T mapOp); + + __intrinsic_op($(kIROp_CoopMatMapElementIFunc)) + static CoopMat<T, S, M, N, R> __MapElement< + TOperator, + TFunc : IFunc<T, uint32_t, uint32_t, T, expand each Ts> + >(This tuple, TOperator mapOp, TFunc mapObj); + + [ForceInline] + CoopMat<T, S, M, N, R> MapElement< + TFunc : IFunc<T, uint32_t, uint32_t, T, expand each Ts> + >(TFunc mapOp) + { + return __MapElement(this, mapOp.operator(), mapOp); + } +}; + } // namespace linalg + // // Cooperative Vector // diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index 18f3f90bc..7cfc6a67a 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -578,7 +578,19 @@ Val* FuncType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet s List<Type*> substParamTypes; for (Index pp = 0; pp < getParamCount(); pp++) { - substParamTypes.add(as<Type>(getParamType(pp)->substituteImpl(astBuilder, subst, &diff))); + auto substParamType = as<Type>(getParamType(pp)->substituteImpl(astBuilder, subst, &diff)); + if (auto typePack = as<ConcreteTypePack>(substParamType)) + { + // Unwrap the ConcreteTypePack and add each element as a parameter + for (Index i = 0; i < typePack->getTypeCount(); ++i) + { + substParamTypes.add(typePack->getElementType(i)); + } + } + else + { + substParamTypes.add(substParamType); + } } // early exit for no change... @@ -774,7 +786,18 @@ Val* ConcreteTypePack::_substituteImplOverride( for (Index i = 0; i < getTypeCount(); i++) { auto substType = as<Type>(getElementType(i)->substituteImpl(astBuilder, subst, &diff)); - substElementTypes.add(substType); + if (auto typePack = as<ConcreteTypePack>(substType)) + { + // Unwrap the ConcreteTypePack and add each element as a parameter + for (Index j = 0; j < typePack->getTypeCount(); ++j) + { + substElementTypes.add(typePack->getElementType(j)); + } + } + else + { + substElementTypes.add(substType); + } } if (!diff) return this; diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index 6f9191135..6259e2fb8 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -1210,6 +1210,191 @@ void SemanticsVisitor::maybeUnifyUnconstraintIntParam( constraints.constraints.add(c); } +struct IndexSpan +{ + Index index; + Index count; + + IndexSpan() + : index(0), count(0) + { + } + IndexSpan(Index idx, Index cnt) + : index(idx), count(cnt) + { + } +}; + +struct IndexSpanPair +{ + IndexSpan first; + IndexSpan second; + + IndexSpanPair() {} + IndexSpanPair(IndexSpan f, IndexSpan s) + : first(f), second(s) + { + } +}; + +// Helper function to unwrap a type and count expandable types +static void unwrapTypeAndCountExpandable( + Type* type, + ShortList<Type*>& outTypes, + int& outExpandableCount) +{ + if (auto concretePack = as<ConcreteTypePack>(type)) + { + for (Index i = 0; i < concretePack->getTypeCount(); ++i) + { + outTypes.add(concretePack->getElementType(i)); + if (isAbstractTypePack(concretePack->getElementType(i))) + outExpandableCount++; + } + } + else if (isAbstractTypePack(type)) + { + outTypes.add(type); + outExpandableCount++; + } +} + +// Helper function to map type arguments between two types, handling expandable types +static bool matchTypeArgMapping( + Type* firstType, + Type* secondType, + ShortList<Type*>& outFlattenedFirst, + ShortList<Type*>& outFlattenedSecond, + ShortList<IndexSpanPair>& outMapping) +{ + // Unwrap and flatten the types + ShortList<Type*>& firstTypes = outFlattenedFirst; + ShortList<Type*>& secondTypes = outFlattenedSecond; + + // Count expandable types as we unwrap + int firstExpandableCount = 0; + int secondExpandableCount = 0; + + // Unwrap both types using the helper function + unwrapTypeAndCountExpandable(firstType, firstTypes, firstExpandableCount); + unwrapTypeAndCountExpandable(secondType, secondTypes, secondExpandableCount); + + // We need to figure out which side should be expanding. + // Consider the following cases, + // + // left = [ expand, expand ] + // right = [ int, float, expand ] + // when one side has more non-expandable types, the other side should expand to match it. + // in this case, "left" should expand to cover "int" and "float". + // + // left = [ int, float, expand, expand ] + // right = [ int, float, expand ] + // when the number of the non-expandable types are same, we want to expand side that has + // fewer expandable types. In this case, "right" should expand to cover the first "expand". + // + // left = ConcreteTypePack(ExpandType, ExpandType) + // right = ConcreteTypePack(int, bool, float, double). + // In this case, we shouldn't be mapping the first ExpandType to int and the second + // ExpandType to bool, float, double. Instead, they should evenly divide the second type + // pack, so we map first ExpandType with int, bool, and second ExpandType to float, double. + // + int firstCount = (int)firstTypes.getCount(); + int secondCount = (int)secondTypes.getCount(); + int countDifference = + (firstCount - firstExpandableCount) - (secondCount - secondExpandableCount); + + bool shouldExpandFirst = + (firstExpandableCount > 0) && + ((countDifference < 0) || + (countDifference == 0 && firstExpandableCount < secondExpandableCount)); + + bool shouldExpandSecond = + (secondExpandableCount > 0) && + ((countDifference > 0) || + (countDifference == 0 && firstExpandableCount > secondExpandableCount)); + + // We need to figure out how much types should match per each expandable type. + int typesPerExpand = 0; + if (shouldExpandSecond) + { + // More types on first, need to expand second + int countToMatch = countDifference + firstExpandableCount; + SLANG_ASSERT(secondExpandableCount != 0); + if (countToMatch % secondExpandableCount != 0) + return false; + typesPerExpand = countToMatch / secondExpandableCount; + } + else if (shouldExpandFirst) + { + // More types on second, need to expand first + int countToMatch = -countDifference + secondExpandableCount; + SLANG_ASSERT(firstExpandableCount != 0); + if (countToMatch % firstExpandableCount != 0) + return false; + typesPerExpand = countToMatch / firstExpandableCount; + } + // If countDifference == 0, no expansion needed + + // Generate the mapping + Index firstIndex = 0; + Index secondIndex = 0; + + while (firstIndex < firstCount && secondIndex < secondCount) + { + IndexSpanPair mapping; + + // Determine spans based on expandable types and count difference + if (shouldExpandFirst) + { + // Expanding first to match second + if (isAbstractTypePack(firstTypes[firstIndex])) + { + mapping.first = IndexSpan(firstIndex, 1); + mapping.second = IndexSpan(secondIndex, typesPerExpand); + secondIndex += typesPerExpand; + } + else + { + mapping.first = IndexSpan(firstIndex, 1); + mapping.second = IndexSpan(secondIndex, 1); + secondIndex++; + } + firstIndex++; + } + else if (shouldExpandSecond) + { + // Expanding second to match first + if (isAbstractTypePack(secondTypes[secondIndex])) + { + mapping.first = IndexSpan(firstIndex, typesPerExpand); + mapping.second = IndexSpan(secondIndex, 1); + firstIndex += typesPerExpand; + } + else + { + mapping.first = IndexSpan(firstIndex, 1); + mapping.second = IndexSpan(secondIndex, 1); + firstIndex++; + } + secondIndex++; + } + else + { + // No expansion needed + mapping.first = IndexSpan(firstIndex, 1); + mapping.second = IndexSpan(secondIndex, 1); + firstIndex++; + secondIndex++; + } + + outMapping.add(mapping); + } + + SLANG_ASSERT(!shouldExpandSecond || firstIndex == firstCount); + SLANG_ASSERT(!shouldExpandFirst || secondIndex == secondCount); + return true; +} + bool SemanticsVisitor::TryUnifyTypes( ConstraintSystem& constraints, ValUnificationContext unifyCtx, @@ -1244,7 +1429,63 @@ bool SemanticsVisitor::TryUnifyTypes( return TryUnifyConjunctionType(constraints, unifyCtx, fst, snd); } - // If one of the types is a type pack, we need to recursively unify the element types. + // Unwrap ConcreteTypePack and call TryUnifyTypes recursively. + ShortList<IndexSpanPair> typeMapping; + ShortList<Type*> flattenedFirst; + ShortList<Type*> flattenedSecond; + if (matchTypeArgMapping(fst, snd, flattenedFirst, flattenedSecond, typeMapping) && + typeMapping.getCount() > 1) + { + // Apply unification based on the mapping + for (const auto& mapping : typeMapping) + { + // Make sure it is one of three cases: 1:1, 1:N or N:1 + SLANG_ASSERT(mapping.first.count > 0 && mapping.second.count > 0); + SLANG_ASSERT(mapping.first.count == 1 || mapping.second.count == 1); + + // Helper function to create QualType from mapping span + auto mayPackAndGetQualType = [this]( + const IndexSpan& span, + ShortList<Type*>& typeList, + bool isLeftValue, + Type* otherType) -> QualType + { + // When the `otherType` is GenericTypePackParamDecl or ExpandType, + // we need to create a ConcreteTypePack so that we can handle them + // recursively. + if (isDeclRefTypeOf<GenericTypePackParamDecl>(otherType) || + as<ExpandType>(otherType)) + { + // Multiple types: create ConcreteTypePack + auto typesView = makeArrayView(&typeList[span.index], span.count); + auto typePack = m_astBuilder->getTypePack(typesView); + return QualType(typePack, isLeftValue); + } + + SLANG_ASSERT(span.count == 1); + return QualType(typeList[span.index], isLeftValue); + }; + + // Get the types directly from the mapping + QualType firstArg = mayPackAndGetQualType( + mapping.first, + flattenedFirst, + fst.isLeftValue, + flattenedSecond[mapping.second.index]); + QualType secondArg = mayPackAndGetQualType( + mapping.second, + flattenedSecond, + snd.isLeftValue, + flattenedFirst[mapping.first.index]); + + // Perform the unification + if (!TryUnifyTypes(constraints, unifyCtx, firstArg, secondArg)) + return false; + } + + return true; + } + if (auto fstTypePack = as<ConcreteTypePack>(fst)) { if (auto sndTypePack = as<ConcreteTypePack>(snd)) diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h index 81ad4bddf..4d33e045e 100644 --- a/source/slang/slang-emit-spirv-ops.h +++ b/source/slang/slang-emit-spirv-ops.h @@ -2666,4 +2666,30 @@ SpvInst* emitOpTypeNodePayloadArray(IRInst* inst, const T& type) type); } +// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCooperativeMatrixPerElementOpNV +template<typename T1, typename T2, typename T3, typename... TOperands> +SpvInst* emiOpCooperativeMatrixPerElementOp( + SpvInstParent* parent, + IRInst* inst, + const T1& idResultType, + const T2& matrix, + const T3& func, + const TOperands&... operands) +{ + static_assert(isSingular<T1>); + static_assert(isSingular<T2>); + static_assert(isSingular<T3>); + // Emit the instruction with a variable number of operands + return emitInst( + parent, + inst, + SpvOpCooperativeMatrixPerElementOpNV, + idResultType, + kResultID, + matrix, + func, + operands...); +} + + #endif // SLANG_IN_SPIRV_EMIT_CONTEXT diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index ba238985b..57ad1a988 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -4276,6 +4276,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex case kIROp_MakeArray: result = emitConstruct(parent, inst); break; + case kIROp_CoopMatMapElementIFunc: + result = emitCoopMatMapElementWithIFunc(parent, as<IRCoopMatMapElementIFunc>(inst)); + break; case kIROp_MakeTensorAddressingTensorLayout: result = emitOpCreateTensorLayout(parent, inst, getID(ensureInst(inst->getDataType()))); break; @@ -7698,6 +7701,53 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex } } + SpvInst* emitCoopMatMapElementWithIFunc(SpvInstParent* parent, IRCoopMatMapElementIFunc* inst) + { + ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_cooperative_matrix2")); + requireSPIRVCapability(SpvCapabilityCooperativeMatrixPerElementOperationsNV); + + IRInst* matOrTuple = inst->getCoopMat(); + + IRInst* mat0 = nullptr; + + UInt tupleCount = 0; + IRInst* tuple = as<IRMakeStruct>(matOrTuple); + if (tuple) + { + mat0 = tuple->getOperand(0); + tupleCount = tuple->getOperandCount(); + } + else + { + mat0 = matOrTuple; + } + + auto funcCall = inst->getIFuncCall(); + + IRInst* ifuncThis = nullptr; + if (inst->getOperandCount() > 2) + ifuncThis = inst->getIFuncThis(); + + return emitInstCustomOperandFunc( + parent, + inst, + SpvOpCooperativeMatrixPerElementOpNV, + [&]() + { + emitOperand(mat0->getDataType()); + emitOperand(kResultID); + + emitOperand(mat0); + emitOperand(funcCall); + + if (ifuncThis) + emitOperand(ifuncThis); + + for (UInt i = 1; i < tupleCount; i++) + emitOperand(tuple->getOperand(i)); + }); + } + SpvInst* emitSplat(SpvInstParent* parent, IRInst* inst, IRInst* scalar, IRIntegerValue numElems) { const auto scalarTy = as<IRBasicType>(scalar->getDataType()); diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 419c8d59d..3b45d46b3 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -789,6 +789,8 @@ INST(AllocateTorchTensor, allocTorchTensor, 0, 0) INST(TorchGetCudaStream, TorchGetCudaStream, 0, 0) INST(TorchTensorGetView, TorchTensorGetView, 0, 0) +INST(CoopMatMapElementIFunc, CoopMatMapElementIFunc, 2, 0) + INST(AllocateOpaqueHandle, allocateOpaqueHandle, 0, 0) // Return the register index thtat a resource is bound to. diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 2ab4db980..2fff4e451 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3124,6 +3124,18 @@ struct IRMakeCoopVector : IRInst IR_LEAF_ISA(MakeCoopVector) }; +struct IRCoopMatMapElementIFunc : IRInst +{ + IR_LEAF_ISA(CoopMatMapElementIFunc) + IRInst* getCoopMat() { return getOperand(0); } + IRInst* getTuple() { return getOperand(0); } + IRFunc* getIFuncCall() { return as<IRFunc>(getOperand(1)); } + IRInst* getIFuncThis() { return getOperand(2); } + + bool hasIFuncThis() { return getOperandCount() > 2; } + void setIFuncCall(IRFunc* func) { setOperand(1, func); } +}; + // An Instruction that creates a differential pair value from a // primal and differential. @@ -4241,6 +4253,8 @@ public: IRInst* emitGetTupleElement(IRType* type, IRInst* tuple, UInt element); IRInst* emitGetTupleElement(IRType* type, IRInst* tuple, IRInst* element); + IRInst* emitCoopMatMapElementFunc(IRType* type, IRInst* tuple, IRInst* func); + IRInst* emitGetElement(IRType* type, IRInst* arrayLikeType, IRIntegerValue element); IRInst* emitGetElementPtr(IRType* type, IRInst* arrayLikeType, IRIntegerValue element); diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp index af29d1998..529770acc 100644 --- a/source/slang/slang-ir-legalize-types.cpp +++ b/source/slang/slang-ir-legalize-types.cpp @@ -2282,6 +2282,36 @@ static LegalVal legalizeGlobalParam( IRTypeLegalizationContext* context, IRGlobalParam* irGlobalParam); +static LegalVal legalizeCoopMatMapElementIFunc( + IRTypeLegalizationContext* context, + IRCoopMatMapElementIFunc* inst) +{ + // When the functor object is a lambda with no captures, + // it will be removed as a part of the legalization process, + // because it is a zero-sized struct. + // We need to explicitly remove it from its user. + if (inst->hasIFuncThis()) + { + // Check if `this` is valid. + auto legalArg = legalizeOperand(context, inst->getIFuncThis()); + if (legalArg.flavor == LegalVal::Flavor::none) + { + // If `this` is not valid, remove it from IR. + IRBuilder builder{inst}; + builder.setInsertBefore(inst); + auto newInst = builder.emitCoopMatMapElementFunc( + inst->getFullType(), + inst->getOperand(0), + inst->getOperand(1)); + + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + return LegalVal::simple(newInst); + } + } + return LegalVal::simple(inst); +} + static LegalVal legalizeInst(IRTypeLegalizationContext* context, IRInst* inst) { // Any additional instructions we need to emit @@ -2293,6 +2323,9 @@ static LegalVal legalizeInst(IRTypeLegalizationContext* context, IRInst* inst) // Special-case certain operations switch (inst->getOp()) { + case kIROp_CoopMatMapElementIFunc: + return legalizeCoopMatMapElementIFunc(context, cast<IRCoopMatMapElementIFunc>(inst)); + case kIROp_Var: return legalizeLocalVar(context, cast<IRVar>(inst)); diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index 20b7f795f..31742cbde 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -1506,6 +1506,82 @@ struct SPIRVLegalizationContext : public SourceEmitterBase } } + // Creates a new function with a different parameter order. + // OpCooperativeMatrixPerElementOpNV expects a callback function to have the parameter + // in the following order: + // return-type CallbackFunction(uint, uint, coopmat1, functorThis, coopmat2, coopmat3, ...) + // But what we have by default is: + // return-type CallbackFunction(functorThis, uint, uint, coopmat1, coopmat2, coopmat3, ...) + // + // The new function will do the following, + // return-type newCallback(uint a, uint b, T mat1, TFunc f, T mat2, T mat3, ...) + // { return targetFunc(f, a, b, mat1, mat2, mat3, ...); } + IRFunc* createWrapperFunctionForPerElement(IRBuilder& builder, IRFunc* targetFunc) + { + List<IRType*> paramTypes; + for (UInt i = 0; i < targetFunc->getParamCount(); i++) + { + paramTypes.add(targetFunc->getParamType(i)); + } + + SLANG_ASSERT(paramTypes.getCount() >= 4); + + IRType* tempTypes[4]; + tempTypes[3] = builder.getPtrType(paramTypes[0]); + tempTypes[0] = paramTypes[1]; + tempTypes[1] = paramTypes[2]; + tempTypes[2] = paramTypes[3]; + paramTypes[0] = tempTypes[0]; + paramTypes[1] = tempTypes[1]; + paramTypes[2] = tempTypes[2]; + paramTypes[3] = tempTypes[3]; + + IRType* returnType = targetFunc->getDataType()->getResultType(); + + IRBuilderInsertLocScope insertLocScope(&builder); + auto wrapperFunc = builder.createFunc(); + builder.setDataType(wrapperFunc, builder.getFuncType(paramTypes, returnType)); + builder.setInsertInto(wrapperFunc); + auto block = builder.emitBlock(); + builder.setInsertInto(block); + + List<IRInst*> params; + for (Index i = 0; i < paramTypes.getCount(); i++) + { + params.add(builder.emitParam(paramTypes[i])); + } + + IRInst* tempParams[4]; + tempParams[0] = builder.emitLoad(params[3]); + tempParams[1] = params[0]; + tempParams[2] = params[1]; + tempParams[3] = params[2]; + params[0] = tempParams[0]; + params[1] = tempParams[1]; + params[2] = tempParams[2]; + params[3] = tempParams[3]; + + auto result = builder.emitCallInst(returnType, targetFunc, params); + builder.emitReturn(result); + return wrapperFunc; + } + + void processCoopMatMapElementIFunc(IRCoopMatMapElementIFunc* inst) + { + IRBuilder builder{inst}; + builder.setInsertBefore(inst); + + auto ifuncCall = inst->getIFuncCall(); + + // `this` of the functor is optional. + // Skip the synthesis if `this` is not passed. + if (ifuncCall->getParamCount() > 3) + { + auto funcSynth = createWrapperFunctionForPerElement(builder, ifuncCall); + inst->setIFuncCall(funcSynth); + } + } + void legalizeSPIRVEntryPoint(IRFunc* func, IREntryPointDecoration* entryPointDecor) { auto stage = entryPointDecor->getProfile().getStage(); @@ -1743,6 +1819,9 @@ struct SPIRVLegalizationContext : public SourceEmitterBase case kIROp_BitfieldInsert: processBitFieldOp(inst); break; + case kIROp_CoopMatMapElementIFunc: + processCoopMatMapElementIFunc(as<IRCoopMatMapElementIFunc>(inst)); + break; case kIROp_DebugValue: if (!isSimpleDataType(as<IRDebugValue>(inst)->getDebugVar()->getDataType())) diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 6d5bd9a25..85fe2fa04 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -4467,6 +4467,12 @@ IRInst* IRBuilder::emitGetTupleElement(IRType* type, IRInst* tuple, UInt element return emitGetTupleElement(type, tuple, getIntValue(getIntType(), element)); } +IRInst* IRBuilder::emitCoopMatMapElementFunc(IRType* type, IRInst* tuple, IRInst* func) +{ + IRInst* args[] = {tuple, func}; + return emitIntrinsicInst(type, kIROp_CoopMatMapElementIFunc, 2, args); +} + IRInst* IRBuilder::emitMakeResultError(IRType* resultType, IRInst* errorVal) { return emitIntrinsicInst(resultType, kIROp_MakeResultError, 1, &errorVal); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 758b21e93..bc00eb531 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1191,6 +1191,13 @@ top: lowered = extractField(context, boundMemberInfo->type, base, fieldDeclRef); goto top; } + else if (auto methodDeclRef = declRef.as<CallableDecl>()) + { + auto funcVal = emitDeclRef(context, declRef, boundMemberInfo->type); + SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple); + lowered = funcVal; + goto top; + } else { @@ -4584,7 +4591,8 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo> else if (auto callableDeclRef = declRef.as<CallableDecl>()) { RefPtr<BoundMemberInfo> boundMemberInfo = new BoundMemberInfo(); - boundMemberInfo->type = nullptr; + boundMemberInfo->type = + lowerType(context, getResultType(context->astBuilder, callableDeclRef)); boundMemberInfo->base = loweredBase; boundMemberInfo->declRef = callableDeclRef; diff --git a/tests/cooperative-matrix/map-element-single.slang b/tests/cooperative-matrix/map-element-single.slang index 1661ee105..ecf35953e 100644 --- a/tests/cooperative-matrix/map-element-single.slang +++ b/tests/cooperative-matrix/map-element-single.slang @@ -1,12 +1,14 @@ -//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly -render-feature cooperative-matrix-per-element-operations -Xslang -DTEST_MODE=0 -//DISABLE_TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly -render-feature cooperative-matrix-per-element-operations -Xslang -DTEST_MODE=1 -//DISABLE_TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly -render-feature cooperative-matrix-per-element-operations -Xslang -DTEST_MODE=2 +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly -Xslang -DTEST_MODE=0 -render-feature cooperative-matrix-per-element-operations +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly -Xslang -DTEST_MODE=1 -render-feature cooperative-matrix-per-element-operations +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly -Xslang -DTEST_MODE=2 -render-feature cooperative-matrix-per-element-operations +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly -Xslang -DTEST_MODE=3 -render-feature cooperative-matrix-per-element-operations +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly -Xslang -DTEST_MODE=4 -render-feature cooperative-matrix-per-element-operations //CHECK: type: int32_t -//CHECK-NEXT: 2 -//CHECK-NEXT: 4 -//CHECK-NEXT: 6 //CHECK-NEXT: 8 +//CHECK-NEXT: 10 +//CHECK-NEXT: 12 +//CHECK-NEXT: 14 //TEST_INPUT:ubuffer(data=[1 2 3 4], stride=4),name=input1 StructuredBuffer<int> input1; @@ -20,7 +22,7 @@ typealias CoopMatType = CoopMat<int, MemoryScope.Subgroup, 16, 16, CoopMatMatrix int MapOp(uint32_t row, uint32_t col, int value) { - return value * 2; + return value * 2 + 1 + 2 + 3; } [numthreads(32, 1, 1)] @@ -29,21 +31,33 @@ void computeMain() let stride = 16; CoopMatType mat1 = CoopMatType.Load<CoopMatMatrixLayout.RowMajor>(input1, 0, stride); + // Testing the capturing lambda + int c0 = 1; + int c1 = 2; + int c2 = 3; + CoopMatType result; #if TEST_MODE == 0 result = mat1.MapElement(MapOp); #elif TEST_MODE == 1 - // Lambda through IFunc. - // TODO: Not working due to issue #7024 - IFunc<int, uint32_t, uint32_t, int> func = ((uint32_t row, uint32_t column, int value) => value * 2); + // Lambda via a temp variable (no capture) + let func = ((uint32_t row, uint32_t column, int value) => value * 2 + 1 + 2 + 3); result = mat1.MapElement(func); #elif TEST_MODE == 2 - // Directly use lambda. - // TODO: Not working due to issue #7024 - result = mat1.MapElement((uint32_t row, uint32_t column, int value) => (int)(value)); + // Directly use lambda (no capture) + result = mat1.MapElement((uint32_t row, uint32_t column, int value) => value * 2 + 1 + 2 + 3); + +#elif TEST_MODE == 3 + // Lambda via a temp variable (capture) + let func = ((uint32_t row, uint32_t column, int value) => value * 2 + c0 + c1 + c2); + result = mat1.MapElement(func); + +#elif TEST_MODE == 4 + // Directly use lambda (capture) + result = mat1.MapElement((uint32_t row, uint32_t column, int value) => value * 2 + c0 + c1 + c2); #endif result.Store<CoopMatMatrixLayout.RowMajor>(outputBuffer, 0, stride); diff --git a/tests/cooperative-matrix/map-element-tuple.slang b/tests/cooperative-matrix/map-element-tuple.slang new file mode 100644 index 000000000..06ab99d8f --- /dev/null +++ b/tests/cooperative-matrix/map-element-tuple.slang @@ -0,0 +1,68 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly -Xslang -DTEST_MODE=0 -render-feature cooperative-matrix-per-element-operations +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly -Xslang -DTEST_MODE=1 -render-feature cooperative-matrix-per-element-operations +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly -Xslang -DTEST_MODE=2 -render-feature cooperative-matrix-per-element-operations +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly -Xslang -DTEST_MODE=3 -render-feature cooperative-matrix-per-element-operations +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly -Xslang -DTEST_MODE=4 -render-feature cooperative-matrix-per-element-operations + +//CHECK:type: int32_t +//CHECK-NEXT:9 +//CHECK-NEXT:12 +//CHECK-NEXT:15 +//CHECK-NEXT:14 + +//TEST_INPUT:ubuffer(data=[1 2 3 4], stride=4),name=input1 +StructuredBuffer<int> input1; + +//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4),name=input2 +StructuredBuffer<int> input2; + +//TEST_INPUT:ubuffer(data=[2 3 4 1], stride=4),name=input3 +StructuredBuffer<int> input3; + +//TEST_INPUT:ubuffer(stride=4, count=256):out,name=outputBuffer +RWStructuredBuffer<int32_t> outputBuffer; + +using namespace linalg; + +typealias CoopMatType = CoopMat<int, MemoryScope.Subgroup, 16, 16, CoopMatMatrixUse.MatrixAccumulator>; + +int MapOp(uint32_t row, uint32_t col, int a, int b, int c) +{ + return a + b + c + 1 + 2 + 3; +} + +[numthreads(32, 1, 1)] +void computeMain() +{ + let stride = 16; + let mat1 = CoopMatType.Load<CoopMatMatrixLayout.RowMajor>(input1, 0, stride); + let mat2 = CoopMatType.Load<CoopMatMatrixLayout.RowMajor>(input2, 0, stride); + let mat3 = CoopMatType.Load<CoopMatMatrixLayout.RowMajor>(input3, 0, stride); + + // Testing the capturing lambda + int c0 = 1; + int c1 = 2; + int c2 = 3; + + CoopMatType result; + +#if TEST_MODE == 0 + result = makeTuple(mat1, mat2, mat3).MapElement(MapOp); + +#elif TEST_MODE == 1 + let f = ((uint32_t x, uint32_t y, int a, int b, int c) => a + b + c + 1 + 2 + 3); + result = makeTuple(mat1, mat2, mat3).MapElement(f); + +#elif TEST_MODE == 2 + result = makeTuple(mat1, mat2, mat3).MapElement((uint32_t x, uint32_t y, int a, int b, int c) => a + b + c + 1 + 2 + 3); + +#elif TEST_MODE == 3 + let f = ((uint32_t x, uint32_t y, int a, int b, int c) => a + b + c + c0 + c1 + c2); + result = makeTuple(mat1, mat2, mat3).MapElement(f); + +#elif TEST_MODE == 4 + result = makeTuple(mat1, mat2, mat3).MapElement((uint32_t x, uint32_t y, int a, int b, int c) => a + b + c + c0 + c1 + c2); +#endif + + result.Store<CoopMatMatrixLayout.RowMajor>(outputBuffer, 0, stride); +} diff --git a/tests/language-feature/tuple/tuple-expand-multiple.slang b/tests/language-feature/tuple/tuple-expand-multiple.slang new file mode 100644 index 000000000..9392e2d66 --- /dev/null +++ b/tests/language-feature/tuple/tuple-expand-multiple.slang @@ -0,0 +1,37 @@ +//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -output-using-type + +//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0 0 0], stride=4) +RWStructuredBuffer<int> outputBuffer; + +extension< + T0 : __BuiltinArithmeticType, + each Ts0 : __BuiltinArithmeticType, + each Ts1 : __BuiltinArithmeticType +> Tuple<T0, T0, expand each Ts0, expand each Ts1> +{ + static int getSize_Ts0() { return countof(Ts0); } + static int getSize_Ts1() { return countof(Ts1); } +} + +[numthreads(1,1,1)] +void computeMain() +{ + int i = 2; + float f0 = 3, f1 = 5; + uint ui0 = 4, ui1 = 6; + + let s0 = makeTuple(i, i); // T, T + let s1 = makeTuple(i, i, f0, ui0); // T, T, Ts0, Ts1 + let s2 = makeTuple(i, i, f0, ui0, f1, ui1); // T, T, Ts0, Ts0, Ts1, Ts1 + + outputBuffer[0] = s0.getSize_Ts0(); + outputBuffer[1] = s0.getSize_Ts1(); + outputBuffer[2] = s1.getSize_Ts0(); + outputBuffer[3] = s1.getSize_Ts1(); + outputBuffer[4] = s2.getSize_Ts0(); + outputBuffer[5] = s2.getSize_Ts1(); + + // CHECK-COUNT-2:0 + // CHECK-COUNT-2:1 + // CHECK-COUNT-2:2 +} diff --git a/tests/language-feature/tuple/tuple-expand.slang b/tests/language-feature/tuple/tuple-expand.slang new file mode 100644 index 000000000..c9ea26161 --- /dev/null +++ b/tests/language-feature/tuple/tuple-expand.slang @@ -0,0 +1,25 @@ +//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -output-using-type + +//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0], stride=4) +RWStructuredBuffer<int> outputBuffer; + +struct X<B, each T, each U> +{ + int getTSize() { return countof(T); } + int getUSize() { return countof(U); } +} + +func foo<each T, each U>() -> X<bool, expand Ptr<each T>, int, expand Ptr<each U>, float> // unify +{ + return {}; +} + +[numthreads(1,1,1)] +void computeMain() +{ + let x = foo<int, float>(); + + outputBuffer[0] = x.getTSize(); + outputBuffer[1] = x.getUSize(); + // CHECK-COUNT-2: 2 +} |
