From f28f67d988158d6c46f7ffe967152f98d32a37b2 Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 30 Jun 2025 14:32:50 -0700 Subject: Add MLP training examples. (#7550) * Add MLP training examples. * Formatting fix. * Fix. * Improve documentation on coopvector. * Improve doc. * Update doc. * Fix typo. * Cleanup shader. * Cleanup. * Fix test. * Fix type check recursion. * Fix. * Fix. * Fix override check. --- source/slang/core.meta.slang | 4 +- source/slang/hlsl.meta.slang | 277 +++++++++++++++++++++++++++++--- source/slang/slang-ast-modifier.h | 7 + source/slang/slang-check-conversion.cpp | 11 +- source/slang/slang-check-decl.cpp | 28 ++-- source/slang/slang-check-expr.cpp | 5 +- source/slang/slang-check-shader.cpp | 10 ++ 7 files changed, 301 insertions(+), 41 deletions(-) (limited to 'source') diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index deaeae439..2e56c1082 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -1809,12 +1809,12 @@ extension Ptr __init(NativeString nativeStr) { this = nativeStr.getBuffer(); } __generic - __intrinsic_op(0) + __intrinsic_op($(kIROp_BitCast)) __implicit_conversion($(kConversionCost_PtrToVoidPtr)) __init(Ptr ptr); __generic - __intrinsic_op(0) + __intrinsic_op($(kIROp_BitCast)) __implicit_conversion($(kConversionCost_PtrToVoidPtr)) __init(NativeRef ptr); } diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 32d7ea824..38f274984 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -26610,16 +26610,24 @@ CoopVec coopVecMatMulPacked( } } -/// Multiply a cooperative vector with a matrix. -/// @param input The input cooperative vector to multiply with the matrix. +/// Multiply a matrix with a cooperative vector. Given a M-row by K-col `matrix`, and a K-element column vector `input`, computes `matrix * input`, and +/// returns a M-element vector. +/// @param input The K-element input cooperative vector to multiply with the matrix. /// @param inputInterpretation Specifies how to interpret the values in the input vector (e.g. as 8-bit integers, 16-bit floats, etc). -/// @param matrix The matrix to multiply with the input vector. +/// @param matrix The M-by-K matrix to multiply with the input vector. /// @param matrixOffset Byte offset into the matrix buffer. /// @param matrixInterpretation Specifies how to interpret the values in the matrix (e.g. as 8-bit integers, 16-bit floats, etc). /// @param memoryLayout Specifies the memory layout of the matrix (row-major or column-major). /// @param transpose Whether to transpose the matrix before multiplication. /// @param matrixStride The stride in bytes between rows/columns of the matrix. /// @return A new cooperative vector containing the result of the matrix multiplication. +/// @remarks Depending on target hardware, some combinations of `inputInterpretation`, `matrixInterpretation` and `memoryLayout` may not be supported. +/// For example, CoopVecComponentType.Float32 is not widely supported. Developers should query device properties through the host graphics API to +/// find out which interpretations are supported. +/// +/// Transposing is not supported when `memoryLayout` is `RowMajor` or `ColumnMajor`, and `transpose` must be `false`. +/// Not all component types support transposing. +/// When `memoryLayout` is `InferencingOptimal` or `TrainingOptimal`, `matrixStride` is ignored. [ForceInline] [require(cooperative_vector)] [require(hlsl_coopvec_poc)] @@ -26650,7 +26658,9 @@ CoopVec coopVecMatMul( matrixStride); } -/// Multiply a cooperative vector with a matrix and add a bias vector. +/// Multiply a matrix with a cooperative vector and add a bias vector to the result. +/// Given a M-row by K-col `matrix`, a K-element column vector `input`, and a M-element vector `bias`, computes `matrix * input + bias`, and +/// returns a M-element vector. /// @param input The input cooperative vector to multiply with the matrix. /// @param inputInterpretation Specifies how to interpret the values in the input vector (e.g. as packed values). /// @param k The number of columns in the matrix. @@ -26667,6 +26677,14 @@ CoopVec coopVecMatMul( /// @remarks Unlike coopVecMatMulAdd, this function supports packed input interpretations where multiple values /// can be packed into each element of the input vector. The k parameter specifies the actual number of /// values to use from the packed input. +/// +/// Depending on target hardware, some combinations of `inputInterpretation`, `matrixInterpretation` and `memoryLayout` may not be supported. +/// For example, CoopVecComponentType.Float32 is not widely supported. Developers should query device properties through the host graphics API to +/// find out which interpretations are supported. +/// +/// Transposing is not supported when `memoryLayout` is `RowMajor` or `ColumnMajor`, and `transpose` must be `false`. +/// Not all component types support transposing. +/// When `memoryLayout` is `InferencingOptimal` or `TrainingOptimal`, `matrixStride` is ignored. [ForceInline] [require(cooperative_vector)] [require(hlsl_coopvec_poc)] @@ -26804,7 +26822,9 @@ CoopVec coopVecMatMulAddPacked coopVecMatMulAddPacked CoopVec coopVecMatMulPacked( @@ -27288,6 +27323,185 @@ void coopVecReduceSumAccumulate +CoopVec coopVecMatMulPacked( + CoopVec input, + constexpr CoopVecComponentType inputInterpretation, + constexpr int k, + void* matrixPtr, + constexpr CoopVecComponentType matrixInterpretation, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride +) +{ + static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK + , "for non-packed inputInterpretation values k must be equal to the input vector length"); + static_assert(!__isPackedInputInterpretation(inputInterpretation) + || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK + , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor"); + __target_switch + { + case spirv: + let m : int32_t = M; + let matrixInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(matrixInterpretation); + let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation); + let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout); + int operands = 0; // NoneKHR + let zero = 0; + let cvtMatPtr = (Ptr)matrixPtr; + if (__isSignedInt()) + { + operands |= 0x08; // MatrixResultSignedComponentsKHR + } + return spirv_asm + { + result:$$CoopVec = OpCooperativeVectorMatrixMulNV $input $inputInterpretationSpirv $cvtMatPtr $zero $matrixInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride !operands; + }; + } +} + +// specialized coopVecMatMul for non-packed inputs +[require(spirv, cooperative_vector)] +__generic +CoopVec coopVecMatMul( + CoopVec input, + constexpr CoopVecComponentType inputInterpretation, + void* matrix, + constexpr CoopVecComponentType matrixInterpretation, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride +) +{ + static_assert(!__isPackedInputInterpretation(inputInterpretation) + , "for packed inputInterpretation values please use coopVecMatMulPacked and specify k manually"); + return coopVecMatMulPacked< + T, M, K, U>( + input, + inputInterpretation, + K, + matrix, + matrixInterpretation, + memoryLayout, + transpose, + matrixStride); +} + +[require(spirv, cooperative_vector)] +CoopVec coopVecMatMulAddPacked( + CoopVec input, + constexpr CoopVecComponentType inputInterpretation, + constexpr int k, + void* matrixPtr, + constexpr CoopVecComponentType matrixInterpretation, + void* biasPtr, + constexpr CoopVecComponentType biasInterpretation, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride +) +{ + static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK + , "for non-packed inputInterpretation values k must be equal to the input vector length"); + static_assert(!__isPackedInputInterpretation(inputInterpretation) + || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK + , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor"); + + __target_switch + { + case spirv: + let m : int32_t = M; + let matrixInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(matrixInterpretation); + let biasInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(biasInterpretation); + let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation); + let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout); + let zero : int32_t = 0; + let cvtMatPtr = (Ptr)matrixPtr; + let cvtBiasPtr = (Ptr)biasPtr; + return spirv_asm + { + result:$$CoopVec = OpCooperativeVectorMatrixMulAddNV $input $inputInterpretationSpirv $cvtMatPtr $zero $matrixInterpretationSpirv $cvtBiasPtr $zero $biasInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride; + }; + } +} + +[require(spirv, cooperative_vector)] +CoopVec coopVecMatMulAdd( + CoopVec input, + constexpr CoopVecComponentType inputInterpretation, + void* matrix, + constexpr CoopVecComponentType matrixInterpretation, + void* bias, + constexpr CoopVecComponentType biasInterpretation, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride +) +{ + static_assert(!__isPackedInputInterpretation(inputInterpretation) + , "for packed inputInterpretation values please use coopVecMatMulPacked and specify k manually"); + return coopVecMatMulAddPacked< + T, M, K, U>( + input, + inputInterpretation, + K, + matrix, + matrixInterpretation, + bias, + biasInterpretation, + memoryLayout, + transpose, + matrixStride); +} + +[require(spirv, cooperative_vector_training)] +void coopVecOuterProductAccumulate( + CoopVec a, + CoopVec b, + void* matrixPtr, + constexpr uint matrixStride, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr CoopVecComponentType matrixInterpretation, +) +{ + let zero : int32_t = 0; + __target_switch + { + case spirv: + let matrixInterpretationSpirv : int = __getSpvCoopVecComponentType(matrixInterpretation); + let memoryLayoutSpirv : int = __getSpvCoopVecMatrixLayout(memoryLayout); + let cvtMatrixPtr = (Ptr)matrixPtr; + spirv_asm + { + OpCapability CooperativeVectorTrainingNV; + OpCooperativeVectorOuterProductAccumulateNV $cvtMatrixPtr $zero $a $b $memoryLayoutSpirv $matrixInterpretationSpirv $matrixStride; + }; + } +} + +[require(spirv, cooperative_vector_training)] +void coopVecReduceSumAccumulate( + CoopVec v, + void* buffer +) +{ + let zero : int32_t = 0; + let bufferPtr = (Ptr)(buffer); + __target_switch + { + case spirv: + spirv_asm + { + OpCapability CooperativeVectorTrainingNV; + OpCooperativeVectorReduceSumAccumulateNV $bufferPtr $zero $v; + }; + } +} + //@public: /// Mark a variable as being workgroup uniform. @@ -28126,3 +28340,24 @@ uint packHalf2x16(half2 unpackedValue) { return packHalf2x16(float2(unpackedValue)); } + +[require(spirv)] +void InterlockedAddF16Emulated(half* dest, half value, out half originalValue) +{ + let buf = (half2*)(dest); + uint64_t byteAddress = (uint64_t)dest; + if ((byteAddress & 3) == 0) + { + originalValue = __atomic_add(*buf, half2(value, half(0.0))).x; + } + else + { + originalValue = __atomic_add(*buf, half2(half(0.0), value)).y; + } +} + +[require(spirv)] +void InterlockedAddF16x2(half2* dest, half2 value, out half2 originalValue) +{ + originalValue = __atomic_add(*dest, value); +} \ No newline at end of file diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 88dea0b7e..56f9d873a 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -187,6 +187,13 @@ class SynthesizedStaticLambdaFuncModifier : public Modifier FIDDLE(...) }; +FIDDLE() +class ExplicitlyDeclaredCapabilityModifier : public Modifier +{ + FIDDLE(...) + FIDDLE() CapabilitySet declaredCapabilityRequirements; +}; + // Marks a synthesized variable as local temporary variable. FIDDLE() class LocalTempVarModifier : public Modifier diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index 506abc1be..6456dbe98 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -419,6 +419,9 @@ bool SemanticsVisitor::createInvokeExprForSynthesizedCtor( if (!structDecl) return false; + if (!structDecl->checkState.isBeingChecked()) + ensureDecl(structDecl, DeclCheckState::AttributesChecked); + HashSet isVisit; bool isCStyle = false; if (!_getSynthesizedConstructor( @@ -656,8 +659,8 @@ bool SemanticsVisitor::_readAggregateValueFromInitializerList( auto toMakeArrayFromElementExpr = m_astBuilder->create(); toMakeArrayFromElementExpr->loc = fromInitializerListExpr->loc; toMakeArrayFromElementExpr->type = QualType(toType); - - *outToExpr = toMakeArrayFromElementExpr; + if (outToExpr) + *outToExpr = toMakeArrayFromElementExpr; return true; } for (UInt ee = 0; ee < elementCount; ++ee) @@ -748,8 +751,8 @@ bool SemanticsVisitor::_readAggregateValueFromInitializerList( auto defaultConstructExpr = m_astBuilder->create(); defaultConstructExpr->loc = fromInitializerListExpr->loc; defaultConstructExpr->type = QualType(toType); - - *outToExpr = defaultConstructExpr; + if (outToExpr) + *outToExpr = defaultConstructExpr; return true; } diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 1a70e25d7..0dd859bb2 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -2914,9 +2914,9 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness context->parentDecl->findLastDirectMemberDeclOfName(requirementDeclRef.getName())) { // Remove the `ToBeSynthesizedModifier`. - if (as(existingDecl->modifiers.first)) + if (auto mod = existingDecl->modifiers.findModifier()) { - existingDecl->modifiers.first = existingDecl->modifiers.first->next; + removeModifier(existingDecl, mod); } else { @@ -3133,14 +3133,9 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness addModifier(aggTypeDecl, m_astBuilder->create()); - // The visibility of synthesized decl should be the min of the parent decl and the requirement. - if (requirementDeclRef.getDecl()->findModifier()) - { - auto requirementVisibility = getDeclVisibility(requirementDeclRef.getDecl()); - auto thisVisibility = getDeclVisibility(context->parentDecl); - auto visibility = Math::Min(thisVisibility, requirementVisibility); - addVisibilityModifier(aggTypeDecl, visibility); - } + // The visibility of synthesized decl should be the same of the parent decl. + auto thisVisibility = getDeclVisibility(context->parentDecl); + addVisibilityModifier(aggTypeDecl, thisVisibility); // Synthesize the rest of IDifferential method conformances by recursively checking // conformance on the synthesized decl. @@ -4149,8 +4144,12 @@ bool SemanticsVisitor::doesVarMatchRequirement( return false; } - auto satisfyingVal = - tryConstantFoldDeclRef(satisfyingMemberDeclRef, ConstantFoldingKind::LinkTime, nullptr); + IntVal* satisfyingVal = nullptr; + if (isValidCompileTimeConstantType(satisfyingType)) + { + satisfyingVal = + tryConstantFoldDeclRef(satisfyingMemberDeclRef, ConstantFoldingKind::LinkTime, nullptr); + } if (satisfyingVal) { witnessTable->add(requiredMemberDeclRef.getDecl(), RequirementWitness(satisfyingVal)); @@ -5125,9 +5124,9 @@ void SemanticsVisitor::markOverridingDecl( return; } + memberDecl = maybeGetInner(memberDecl); if (hasDefaultImpl(requiredMemberDeclRef)) { - memberDecl = maybeGetInner(memberDecl); // If the required member has a default implementation, // we need to make sure the member we found is marked as 'override'. if (!memberDecl->hasModifier()) @@ -14290,6 +14289,9 @@ void SemanticsDeclCapabilityVisitor::visitFunctionDeclBase(FunctionDeclBase* fun } else { + auto declaredCapModifier = m_astBuilder->create(); + declaredCapModifier->declaredCapabilityRequirements = declaredCaps; + addModifier(funcDecl, declaredCapModifier); if (vis == DeclVisibility::Public) { // For public decls, we need to enforce that the function diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 306687bd8..b90081af8 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -682,6 +682,8 @@ Expr* SemanticsVisitor::maybeUseSynthesizedDeclForLookupResult( auto typeDef = m_astBuilder->create(); typeDef->nameAndLoc.name = getName("Differential"); typeDef->parentDecl = structDecl; + addVisibilityModifier(structDecl, getDeclVisibility(parent)); + addVisibilityModifier(typeDef, getDeclVisibility(parent)); auto synthDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl)); @@ -714,6 +716,7 @@ Expr* SemanticsVisitor::maybeUseSynthesizedDeclForLookupResult( typeDef->type.type = calcThisType(subType->getDeclRef().getDecl()->getDefaultDeclRef()); + addVisibilityModifier(typeDef, getDeclVisibility(parent)); synthesizedDecl = parent; parent->addDirectMemberDecl(typeDef); @@ -2085,7 +2088,7 @@ IntVal* SemanticsVisitor::tryConstantFoldDeclRef( // to not allow such cases. // // Note that float-to-inst casts for non-`IntVal`s are allowed. - if (!isScalarIntegerType(decl->getType())) + if (!isValidCompileTimeConstantType(decl->getType())) { getSink()->diagnose(declRef, Diagnostics::intValFromNonIntSpecConstEncountered); return nullptr; diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index e6744071b..a360361f7 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -553,6 +553,16 @@ void validateEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink) targetOptionSet.hasOption(CompilerOptionName::Capability) && (targetOptionSet.getIntOption(CompilerOptionName::Capability) != SLANG_CAPABILITY_UNKNOWN); + + if (auto declaredCapsMod = + entryPointFuncDecl->findModifier()) + { + // If the entry point has an explicitly declared capability, then we + // will merge that with the target capability set before checking if + // there is an implicit upgrade. + targetCaps.nonDestructiveJoin(declaredCapsMod->declaredCapabilityRequirements); + } + // Only attempt to error if a specific profile or capability is requested if ((specificCapabilityRequested || specificProfileRequested) && targetCaps.atLeastOneSetImpliedInOther( -- cgit v1.2.3