diff options
| author | Jay Kwak <82421531+jkwak-work@users.noreply.github.com> | 2025-05-29 16:36:49 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-05-29 16:36:49 -0700 |
| commit | 984d7f22f8a0909dc870c65bb927094c54f55402 (patch) | |
| tree | ab255bf44e14f6cbaa09522f90b12464f1c6a339 /source | |
| parent | f4d7954e088966c2ae8618b1cc17aac4d64ef013 (diff) | |
Implement MapElement for CoopMat (#7159)
With this PR, MapElement works for the following signatures:
- CoopMat<...>::MapElement(functype(...));
- CoopMat<...>::MapElement(capturing-lambda);
- CoopMat<...>::MapElement(not-capturing-lambda);
- Tuple<CoopMat<...>,...>::MapElement(functype(...));
- Tuple<CoopMat<...>,...>::MapElement(capturing-lambda);
- Tuple<CoopMat<...>,...>::MapElement(not-capturing-lambda);
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/hlsl.meta.slang | 45 | ||||
| -rw-r--r-- | source/slang/slang-ast-type.cpp | 27 | ||||
| -rw-r--r-- | source/slang/slang-check-constraint.cpp | 243 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv-ops.h | 26 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 50 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 14 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-types.cpp | 33 | ||||
| -rw-r--r-- | source/slang/slang-ir-spirv-legalize.cpp | 79 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 10 |
11 files changed, 531 insertions, 4 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 801c61481..677b5d7bf 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -23415,6 +23415,23 @@ struct CoopMat }; } + __intrinsic_op($(kIROp_CoopMatMapElementIFunc)) + internal static This __MapElement< + TOperator, + TFunc : IFunc<T, uint32_t, uint32_t, T> + >( + This coopMat, + TOperator mapOp, + TFunc mapObj + ); + + This MapElement< + TFunc : IFunc<T, uint32_t, uint32_t, T> + >(TFunc mapOp) + { + return __MapElement(this, mapOp.operator(), mapOp); + } + // // Store // @@ -24276,8 +24293,36 @@ CoopMat<T, S, M, N, CoopMatMatrixUse.MatrixAccumulator> coopMatMulAdd< }; } +extension< + T : __BuiltinArithmeticType, + let S : MemoryScope, + let M : int, + let N : int, + let R : linalg.CoopMatMatrixUse, + each Ts : __BuiltinArithmeticType +> Tuple<linalg.CoopMat<T, S, M, N, R>, expand linalg.CoopMat<each Ts, S, M, N, R>> +{ + __intrinsic_op($(kIROp_CoopMatMapElementIFunc)) + CoopMat<T, S, M, N, R> MapElement(functype(uint32_t, uint32_t, T, expand each Ts)->T mapOp); + + __intrinsic_op($(kIROp_CoopMatMapElementIFunc)) + static CoopMat<T, S, M, N, R> __MapElement< + TOperator, + TFunc : IFunc<T, uint32_t, uint32_t, T, expand each Ts> + >(This tuple, TOperator mapOp, TFunc mapObj); + + [ForceInline] + CoopMat<T, S, M, N, R> MapElement< + TFunc : IFunc<T, uint32_t, uint32_t, T, expand each Ts> + >(TFunc mapOp) + { + return __MapElement(this, mapOp.operator(), mapOp); + } +}; + } // namespace linalg + // // Cooperative Vector // diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index 18f3f90bc..7cfc6a67a 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -578,7 +578,19 @@ Val* FuncType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet s List<Type*> substParamTypes; for (Index pp = 0; pp < getParamCount(); pp++) { - substParamTypes.add(as<Type>(getParamType(pp)->substituteImpl(astBuilder, subst, &diff))); + auto substParamType = as<Type>(getParamType(pp)->substituteImpl(astBuilder, subst, &diff)); + if (auto typePack = as<ConcreteTypePack>(substParamType)) + { + // Unwrap the ConcreteTypePack and add each element as a parameter + for (Index i = 0; i < typePack->getTypeCount(); ++i) + { + substParamTypes.add(typePack->getElementType(i)); + } + } + else + { + substParamTypes.add(substParamType); + } } // early exit for no change... @@ -774,7 +786,18 @@ Val* ConcreteTypePack::_substituteImplOverride( for (Index i = 0; i < getTypeCount(); i++) { auto substType = as<Type>(getElementType(i)->substituteImpl(astBuilder, subst, &diff)); - substElementTypes.add(substType); + if (auto typePack = as<ConcreteTypePack>(substType)) + { + // Unwrap the ConcreteTypePack and add each element as a parameter + for (Index j = 0; j < typePack->getTypeCount(); ++j) + { + substElementTypes.add(typePack->getElementType(j)); + } + } + else + { + substElementTypes.add(substType); + } } if (!diff) return this; diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index 6f9191135..6259e2fb8 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -1210,6 +1210,191 @@ void SemanticsVisitor::maybeUnifyUnconstraintIntParam( constraints.constraints.add(c); } +struct IndexSpan +{ + Index index; + Index count; + + IndexSpan() + : index(0), count(0) + { + } + IndexSpan(Index idx, Index cnt) + : index(idx), count(cnt) + { + } +}; + +struct IndexSpanPair +{ + IndexSpan first; + IndexSpan second; + + IndexSpanPair() {} + IndexSpanPair(IndexSpan f, IndexSpan s) + : first(f), second(s) + { + } +}; + +// Helper function to unwrap a type and count expandable types +static void unwrapTypeAndCountExpandable( + Type* type, + ShortList<Type*>& outTypes, + int& outExpandableCount) +{ + if (auto concretePack = as<ConcreteTypePack>(type)) + { + for (Index i = 0; i < concretePack->getTypeCount(); ++i) + { + outTypes.add(concretePack->getElementType(i)); + if (isAbstractTypePack(concretePack->getElementType(i))) + outExpandableCount++; + } + } + else if (isAbstractTypePack(type)) + { + outTypes.add(type); + outExpandableCount++; + } +} + +// Helper function to map type arguments between two types, handling expandable types +static bool matchTypeArgMapping( + Type* firstType, + Type* secondType, + ShortList<Type*>& outFlattenedFirst, + ShortList<Type*>& outFlattenedSecond, + ShortList<IndexSpanPair>& outMapping) +{ + // Unwrap and flatten the types + ShortList<Type*>& firstTypes = outFlattenedFirst; + ShortList<Type*>& secondTypes = outFlattenedSecond; + + // Count expandable types as we unwrap + int firstExpandableCount = 0; + int secondExpandableCount = 0; + + // Unwrap both types using the helper function + unwrapTypeAndCountExpandable(firstType, firstTypes, firstExpandableCount); + unwrapTypeAndCountExpandable(secondType, secondTypes, secondExpandableCount); + + // We need to figure out which side should be expanding. + // Consider the following cases, + // + // left = [ expand, expand ] + // right = [ int, float, expand ] + // when one side has more non-expandable types, the other side should expand to match it. + // in this case, "left" should expand to cover "int" and "float". + // + // left = [ int, float, expand, expand ] + // right = [ int, float, expand ] + // when the number of the non-expandable types are same, we want to expand side that has + // fewer expandable types. In this case, "right" should expand to cover the first "expand". + // + // left = ConcreteTypePack(ExpandType, ExpandType) + // right = ConcreteTypePack(int, bool, float, double). + // In this case, we shouldn't be mapping the first ExpandType to int and the second + // ExpandType to bool, float, double. Instead, they should evenly divide the second type + // pack, so we map first ExpandType with int, bool, and second ExpandType to float, double. + // + int firstCount = (int)firstTypes.getCount(); + int secondCount = (int)secondTypes.getCount(); + int countDifference = + (firstCount - firstExpandableCount) - (secondCount - secondExpandableCount); + + bool shouldExpandFirst = + (firstExpandableCount > 0) && + ((countDifference < 0) || + (countDifference == 0 && firstExpandableCount < secondExpandableCount)); + + bool shouldExpandSecond = + (secondExpandableCount > 0) && + ((countDifference > 0) || + (countDifference == 0 && firstExpandableCount > secondExpandableCount)); + + // We need to figure out how much types should match per each expandable type. + int typesPerExpand = 0; + if (shouldExpandSecond) + { + // More types on first, need to expand second + int countToMatch = countDifference + firstExpandableCount; + SLANG_ASSERT(secondExpandableCount != 0); + if (countToMatch % secondExpandableCount != 0) + return false; + typesPerExpand = countToMatch / secondExpandableCount; + } + else if (shouldExpandFirst) + { + // More types on second, need to expand first + int countToMatch = -countDifference + secondExpandableCount; + SLANG_ASSERT(firstExpandableCount != 0); + if (countToMatch % firstExpandableCount != 0) + return false; + typesPerExpand = countToMatch / firstExpandableCount; + } + // If countDifference == 0, no expansion needed + + // Generate the mapping + Index firstIndex = 0; + Index secondIndex = 0; + + while (firstIndex < firstCount && secondIndex < secondCount) + { + IndexSpanPair mapping; + + // Determine spans based on expandable types and count difference + if (shouldExpandFirst) + { + // Expanding first to match second + if (isAbstractTypePack(firstTypes[firstIndex])) + { + mapping.first = IndexSpan(firstIndex, 1); + mapping.second = IndexSpan(secondIndex, typesPerExpand); + secondIndex += typesPerExpand; + } + else + { + mapping.first = IndexSpan(firstIndex, 1); + mapping.second = IndexSpan(secondIndex, 1); + secondIndex++; + } + firstIndex++; + } + else if (shouldExpandSecond) + { + // Expanding second to match first + if (isAbstractTypePack(secondTypes[secondIndex])) + { + mapping.first = IndexSpan(firstIndex, typesPerExpand); + mapping.second = IndexSpan(secondIndex, 1); + firstIndex += typesPerExpand; + } + else + { + mapping.first = IndexSpan(firstIndex, 1); + mapping.second = IndexSpan(secondIndex, 1); + firstIndex++; + } + secondIndex++; + } + else + { + // No expansion needed + mapping.first = IndexSpan(firstIndex, 1); + mapping.second = IndexSpan(secondIndex, 1); + firstIndex++; + secondIndex++; + } + + outMapping.add(mapping); + } + + SLANG_ASSERT(!shouldExpandSecond || firstIndex == firstCount); + SLANG_ASSERT(!shouldExpandFirst || secondIndex == secondCount); + return true; +} + bool SemanticsVisitor::TryUnifyTypes( ConstraintSystem& constraints, ValUnificationContext unifyCtx, @@ -1244,7 +1429,63 @@ bool SemanticsVisitor::TryUnifyTypes( return TryUnifyConjunctionType(constraints, unifyCtx, fst, snd); } - // If one of the types is a type pack, we need to recursively unify the element types. + // Unwrap ConcreteTypePack and call TryUnifyTypes recursively. + ShortList<IndexSpanPair> typeMapping; + ShortList<Type*> flattenedFirst; + ShortList<Type*> flattenedSecond; + if (matchTypeArgMapping(fst, snd, flattenedFirst, flattenedSecond, typeMapping) && + typeMapping.getCount() > 1) + { + // Apply unification based on the mapping + for (const auto& mapping : typeMapping) + { + // Make sure it is one of three cases: 1:1, 1:N or N:1 + SLANG_ASSERT(mapping.first.count > 0 && mapping.second.count > 0); + SLANG_ASSERT(mapping.first.count == 1 || mapping.second.count == 1); + + // Helper function to create QualType from mapping span + auto mayPackAndGetQualType = [this]( + const IndexSpan& span, + ShortList<Type*>& typeList, + bool isLeftValue, + Type* otherType) -> QualType + { + // When the `otherType` is GenericTypePackParamDecl or ExpandType, + // we need to create a ConcreteTypePack so that we can handle them + // recursively. + if (isDeclRefTypeOf<GenericTypePackParamDecl>(otherType) || + as<ExpandType>(otherType)) + { + // Multiple types: create ConcreteTypePack + auto typesView = makeArrayView(&typeList[span.index], span.count); + auto typePack = m_astBuilder->getTypePack(typesView); + return QualType(typePack, isLeftValue); + } + + SLANG_ASSERT(span.count == 1); + return QualType(typeList[span.index], isLeftValue); + }; + + // Get the types directly from the mapping + QualType firstArg = mayPackAndGetQualType( + mapping.first, + flattenedFirst, + fst.isLeftValue, + flattenedSecond[mapping.second.index]); + QualType secondArg = mayPackAndGetQualType( + mapping.second, + flattenedSecond, + snd.isLeftValue, + flattenedFirst[mapping.first.index]); + + // Perform the unification + if (!TryUnifyTypes(constraints, unifyCtx, firstArg, secondArg)) + return false; + } + + return true; + } + if (auto fstTypePack = as<ConcreteTypePack>(fst)) { if (auto sndTypePack = as<ConcreteTypePack>(snd)) diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h index 81ad4bddf..4d33e045e 100644 --- a/source/slang/slang-emit-spirv-ops.h +++ b/source/slang/slang-emit-spirv-ops.h @@ -2666,4 +2666,30 @@ SpvInst* emitOpTypeNodePayloadArray(IRInst* inst, const T& type) type); } +// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCooperativeMatrixPerElementOpNV +template<typename T1, typename T2, typename T3, typename... TOperands> +SpvInst* emiOpCooperativeMatrixPerElementOp( + SpvInstParent* parent, + IRInst* inst, + const T1& idResultType, + const T2& matrix, + const T3& func, + const TOperands&... operands) +{ + static_assert(isSingular<T1>); + static_assert(isSingular<T2>); + static_assert(isSingular<T3>); + // Emit the instruction with a variable number of operands + return emitInst( + parent, + inst, + SpvOpCooperativeMatrixPerElementOpNV, + idResultType, + kResultID, + matrix, + func, + operands...); +} + + #endif // SLANG_IN_SPIRV_EMIT_CONTEXT diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index ba238985b..57ad1a988 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -4276,6 +4276,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex case kIROp_MakeArray: result = emitConstruct(parent, inst); break; + case kIROp_CoopMatMapElementIFunc: + result = emitCoopMatMapElementWithIFunc(parent, as<IRCoopMatMapElementIFunc>(inst)); + break; case kIROp_MakeTensorAddressingTensorLayout: result = emitOpCreateTensorLayout(parent, inst, getID(ensureInst(inst->getDataType()))); break; @@ -7698,6 +7701,53 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex } } + SpvInst* emitCoopMatMapElementWithIFunc(SpvInstParent* parent, IRCoopMatMapElementIFunc* inst) + { + ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_cooperative_matrix2")); + requireSPIRVCapability(SpvCapabilityCooperativeMatrixPerElementOperationsNV); + + IRInst* matOrTuple = inst->getCoopMat(); + + IRInst* mat0 = nullptr; + + UInt tupleCount = 0; + IRInst* tuple = as<IRMakeStruct>(matOrTuple); + if (tuple) + { + mat0 = tuple->getOperand(0); + tupleCount = tuple->getOperandCount(); + } + else + { + mat0 = matOrTuple; + } + + auto funcCall = inst->getIFuncCall(); + + IRInst* ifuncThis = nullptr; + if (inst->getOperandCount() > 2) + ifuncThis = inst->getIFuncThis(); + + return emitInstCustomOperandFunc( + parent, + inst, + SpvOpCooperativeMatrixPerElementOpNV, + [&]() + { + emitOperand(mat0->getDataType()); + emitOperand(kResultID); + + emitOperand(mat0); + emitOperand(funcCall); + + if (ifuncThis) + emitOperand(ifuncThis); + + for (UInt i = 1; i < tupleCount; i++) + emitOperand(tuple->getOperand(i)); + }); + } + SpvInst* emitSplat(SpvInstParent* parent, IRInst* inst, IRInst* scalar, IRIntegerValue numElems) { const auto scalarTy = as<IRBasicType>(scalar->getDataType()); diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 419c8d59d..3b45d46b3 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -789,6 +789,8 @@ INST(AllocateTorchTensor, allocTorchTensor, 0, 0) INST(TorchGetCudaStream, TorchGetCudaStream, 0, 0) INST(TorchTensorGetView, TorchTensorGetView, 0, 0) +INST(CoopMatMapElementIFunc, CoopMatMapElementIFunc, 2, 0) + INST(AllocateOpaqueHandle, allocateOpaqueHandle, 0, 0) // Return the register index thtat a resource is bound to. diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 2ab4db980..2fff4e451 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3124,6 +3124,18 @@ struct IRMakeCoopVector : IRInst IR_LEAF_ISA(MakeCoopVector) }; +struct IRCoopMatMapElementIFunc : IRInst +{ + IR_LEAF_ISA(CoopMatMapElementIFunc) + IRInst* getCoopMat() { return getOperand(0); } + IRInst* getTuple() { return getOperand(0); } + IRFunc* getIFuncCall() { return as<IRFunc>(getOperand(1)); } + IRInst* getIFuncThis() { return getOperand(2); } + + bool hasIFuncThis() { return getOperandCount() > 2; } + void setIFuncCall(IRFunc* func) { setOperand(1, func); } +}; + // An Instruction that creates a differential pair value from a // primal and differential. @@ -4241,6 +4253,8 @@ public: IRInst* emitGetTupleElement(IRType* type, IRInst* tuple, UInt element); IRInst* emitGetTupleElement(IRType* type, IRInst* tuple, IRInst* element); + IRInst* emitCoopMatMapElementFunc(IRType* type, IRInst* tuple, IRInst* func); + IRInst* emitGetElement(IRType* type, IRInst* arrayLikeType, IRIntegerValue element); IRInst* emitGetElementPtr(IRType* type, IRInst* arrayLikeType, IRIntegerValue element); diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp index af29d1998..529770acc 100644 --- a/source/slang/slang-ir-legalize-types.cpp +++ b/source/slang/slang-ir-legalize-types.cpp @@ -2282,6 +2282,36 @@ static LegalVal legalizeGlobalParam( IRTypeLegalizationContext* context, IRGlobalParam* irGlobalParam); +static LegalVal legalizeCoopMatMapElementIFunc( + IRTypeLegalizationContext* context, + IRCoopMatMapElementIFunc* inst) +{ + // When the functor object is a lambda with no captures, + // it will be removed as a part of the legalization process, + // because it is a zero-sized struct. + // We need to explicitly remove it from its user. + if (inst->hasIFuncThis()) + { + // Check if `this` is valid. + auto legalArg = legalizeOperand(context, inst->getIFuncThis()); + if (legalArg.flavor == LegalVal::Flavor::none) + { + // If `this` is not valid, remove it from IR. + IRBuilder builder{inst}; + builder.setInsertBefore(inst); + auto newInst = builder.emitCoopMatMapElementFunc( + inst->getFullType(), + inst->getOperand(0), + inst->getOperand(1)); + + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + return LegalVal::simple(newInst); + } + } + return LegalVal::simple(inst); +} + static LegalVal legalizeInst(IRTypeLegalizationContext* context, IRInst* inst) { // Any additional instructions we need to emit @@ -2293,6 +2323,9 @@ static LegalVal legalizeInst(IRTypeLegalizationContext* context, IRInst* inst) // Special-case certain operations switch (inst->getOp()) { + case kIROp_CoopMatMapElementIFunc: + return legalizeCoopMatMapElementIFunc(context, cast<IRCoopMatMapElementIFunc>(inst)); + case kIROp_Var: return legalizeLocalVar(context, cast<IRVar>(inst)); diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index 20b7f795f..31742cbde 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -1506,6 +1506,82 @@ struct SPIRVLegalizationContext : public SourceEmitterBase } } + // Creates a new function with a different parameter order. + // OpCooperativeMatrixPerElementOpNV expects a callback function to have the parameter + // in the following order: + // return-type CallbackFunction(uint, uint, coopmat1, functorThis, coopmat2, coopmat3, ...) + // But what we have by default is: + // return-type CallbackFunction(functorThis, uint, uint, coopmat1, coopmat2, coopmat3, ...) + // + // The new function will do the following, + // return-type newCallback(uint a, uint b, T mat1, TFunc f, T mat2, T mat3, ...) + // { return targetFunc(f, a, b, mat1, mat2, mat3, ...); } + IRFunc* createWrapperFunctionForPerElement(IRBuilder& builder, IRFunc* targetFunc) + { + List<IRType*> paramTypes; + for (UInt i = 0; i < targetFunc->getParamCount(); i++) + { + paramTypes.add(targetFunc->getParamType(i)); + } + + SLANG_ASSERT(paramTypes.getCount() >= 4); + + IRType* tempTypes[4]; + tempTypes[3] = builder.getPtrType(paramTypes[0]); + tempTypes[0] = paramTypes[1]; + tempTypes[1] = paramTypes[2]; + tempTypes[2] = paramTypes[3]; + paramTypes[0] = tempTypes[0]; + paramTypes[1] = tempTypes[1]; + paramTypes[2] = tempTypes[2]; + paramTypes[3] = tempTypes[3]; + + IRType* returnType = targetFunc->getDataType()->getResultType(); + + IRBuilderInsertLocScope insertLocScope(&builder); + auto wrapperFunc = builder.createFunc(); + builder.setDataType(wrapperFunc, builder.getFuncType(paramTypes, returnType)); + builder.setInsertInto(wrapperFunc); + auto block = builder.emitBlock(); + builder.setInsertInto(block); + + List<IRInst*> params; + for (Index i = 0; i < paramTypes.getCount(); i++) + { + params.add(builder.emitParam(paramTypes[i])); + } + + IRInst* tempParams[4]; + tempParams[0] = builder.emitLoad(params[3]); + tempParams[1] = params[0]; + tempParams[2] = params[1]; + tempParams[3] = params[2]; + params[0] = tempParams[0]; + params[1] = tempParams[1]; + params[2] = tempParams[2]; + params[3] = tempParams[3]; + + auto result = builder.emitCallInst(returnType, targetFunc, params); + builder.emitReturn(result); + return wrapperFunc; + } + + void processCoopMatMapElementIFunc(IRCoopMatMapElementIFunc* inst) + { + IRBuilder builder{inst}; + builder.setInsertBefore(inst); + + auto ifuncCall = inst->getIFuncCall(); + + // `this` of the functor is optional. + // Skip the synthesis if `this` is not passed. + if (ifuncCall->getParamCount() > 3) + { + auto funcSynth = createWrapperFunctionForPerElement(builder, ifuncCall); + inst->setIFuncCall(funcSynth); + } + } + void legalizeSPIRVEntryPoint(IRFunc* func, IREntryPointDecoration* entryPointDecor) { auto stage = entryPointDecor->getProfile().getStage(); @@ -1743,6 +1819,9 @@ struct SPIRVLegalizationContext : public SourceEmitterBase case kIROp_BitfieldInsert: processBitFieldOp(inst); break; + case kIROp_CoopMatMapElementIFunc: + processCoopMatMapElementIFunc(as<IRCoopMatMapElementIFunc>(inst)); + break; case kIROp_DebugValue: if (!isSimpleDataType(as<IRDebugValue>(inst)->getDebugVar()->getDataType())) diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 6d5bd9a25..85fe2fa04 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -4467,6 +4467,12 @@ IRInst* IRBuilder::emitGetTupleElement(IRType* type, IRInst* tuple, UInt element return emitGetTupleElement(type, tuple, getIntValue(getIntType(), element)); } +IRInst* IRBuilder::emitCoopMatMapElementFunc(IRType* type, IRInst* tuple, IRInst* func) +{ + IRInst* args[] = {tuple, func}; + return emitIntrinsicInst(type, kIROp_CoopMatMapElementIFunc, 2, args); +} + IRInst* IRBuilder::emitMakeResultError(IRType* resultType, IRInst* errorVal) { return emitIntrinsicInst(resultType, kIROp_MakeResultError, 1, &errorVal); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 758b21e93..bc00eb531 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1191,6 +1191,13 @@ top: lowered = extractField(context, boundMemberInfo->type, base, fieldDeclRef); goto top; } + else if (auto methodDeclRef = declRef.as<CallableDecl>()) + { + auto funcVal = emitDeclRef(context, declRef, boundMemberInfo->type); + SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple); + lowered = funcVal; + goto top; + } else { @@ -4584,7 +4591,8 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo> else if (auto callableDeclRef = declRef.as<CallableDecl>()) { RefPtr<BoundMemberInfo> boundMemberInfo = new BoundMemberInfo(); - boundMemberInfo->type = nullptr; + boundMemberInfo->type = + lowerType(context, getResultType(context->astBuilder, callableDeclRef)); boundMemberInfo->base = loweredBase; boundMemberInfo->declRef = callableDeclRef; |
