diff options
| author | Yong He <yonghe@outlook.com> | 2025-06-30 14:32:50 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-06-30 21:32:50 +0000 |
| commit | f28f67d988158d6c46f7ffe967152f98d32a37b2 (patch) | |
| tree | 2aa620986a87ec69cf1f210c714312e42b62ac9e /source | |
| parent | a55ff722cae338a8fcf5402858c47cf0650a8e5e (diff) | |
Add MLP training examples. (#7550)
* Add MLP training examples.
* Formatting fix.
* Fix.
* Improve documentation on coopvector.
* Improve doc.
* Update doc.
* Fix typo.
* Cleanup shader.
* Cleanup.
* Fix test.
* Fix type check recursion.
* Fix.
* Fix.
* Fix override check.
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/core.meta.slang | 4 | ||||
| -rw-r--r-- | source/slang/hlsl.meta.slang | 277 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-check-conversion.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 28 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-check-shader.cpp | 10 |
7 files changed, 301 insertions, 41 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index deaeae439..2e56c1082 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -1809,12 +1809,12 @@ extension Ptr<void> __init(NativeString nativeStr) { this = nativeStr.getBuffer(); } __generic<T, let addrSpace : uint64_t> - __intrinsic_op(0) + __intrinsic_op($(kIROp_BitCast)) __implicit_conversion($(kConversionCost_PtrToVoidPtr)) __init(Ptr<T, addrSpace> ptr); __generic<T> - __intrinsic_op(0) + __intrinsic_op($(kIROp_BitCast)) __implicit_conversion($(kConversionCost_PtrToVoidPtr)) __init(NativeRef<T> ptr); } diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 32d7ea824..38f274984 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -26610,16 +26610,24 @@ CoopVec<T, M> coopVecMatMulPacked( } } -/// Multiply a cooperative vector with a matrix. -/// @param input The input cooperative vector to multiply with the matrix. +/// Multiply a matrix with a cooperative vector. Given a M-row by K-col `matrix`, and a K-element column vector `input`, computes `matrix * input`, and +/// returns a M-element vector. +/// @param input The K-element input cooperative vector to multiply with the matrix. /// @param inputInterpretation Specifies how to interpret the values in the input vector (e.g. as 8-bit integers, 16-bit floats, etc). -/// @param matrix The matrix to multiply with the input vector. +/// @param matrix The M-by-K matrix to multiply with the input vector. /// @param matrixOffset Byte offset into the matrix buffer. /// @param matrixInterpretation Specifies how to interpret the values in the matrix (e.g. as 8-bit integers, 16-bit floats, etc). /// @param memoryLayout Specifies the memory layout of the matrix (row-major or column-major). /// @param transpose Whether to transpose the matrix before multiplication. /// @param matrixStride The stride in bytes between rows/columns of the matrix. /// @return A new cooperative vector containing the result of the matrix multiplication. +/// @remarks Depending on target hardware, some combinations of `inputInterpretation`, `matrixInterpretation` and `memoryLayout` may not be supported. +/// For example, CoopVecComponentType.Float32 is not widely supported. Developers should query device properties through the host graphics API to +/// find out which interpretations are supported. +/// +/// Transposing is not supported when `memoryLayout` is `RowMajor` or `ColumnMajor`, and `transpose` must be `false`. +/// Not all component types support transposing. +/// When `memoryLayout` is `InferencingOptimal` or `TrainingOptimal`, `matrixStride` is ignored. [ForceInline] [require(cooperative_vector)] [require(hlsl_coopvec_poc)] @@ -26650,7 +26658,9 @@ CoopVec<T, M> coopVecMatMul( matrixStride); } -/// Multiply a cooperative vector with a matrix and add a bias vector. +/// Multiply a matrix with a cooperative vector and add a bias vector to the result. +/// Given a M-row by K-col `matrix`, a K-element column vector `input`, and a M-element vector `bias`, computes `matrix * input + bias`, and +/// returns a M-element vector. /// @param input The input cooperative vector to multiply with the matrix. /// @param inputInterpretation Specifies how to interpret the values in the input vector (e.g. as packed values). /// @param k The number of columns in the matrix. @@ -26667,6 +26677,14 @@ CoopVec<T, M> coopVecMatMul( /// @remarks Unlike coopVecMatMulAdd, this function supports packed input interpretations where multiple values /// can be packed into each element of the input vector. The k parameter specifies the actual number of /// values to use from the packed input. +/// +/// Depending on target hardware, some combinations of `inputInterpretation`, `matrixInterpretation` and `memoryLayout` may not be supported. +/// For example, CoopVecComponentType.Float32 is not widely supported. Developers should query device properties through the host graphics API to +/// find out which interpretations are supported. +/// +/// Transposing is not supported when `memoryLayout` is `RowMajor` or `ColumnMajor`, and `transpose` must be `false`. +/// Not all component types support transposing. +/// When `memoryLayout` is `InferencingOptimal` or `TrainingOptimal`, `matrixStride` is ignored. [ForceInline] [require(cooperative_vector)] [require(hlsl_coopvec_poc)] @@ -26804,7 +26822,9 @@ CoopVec<T, M> coopVecMatMulAddPacked<T : __BuiltinArithmeticType, let M : int, l } } -/// Multiply a cooperative vector with a matrix and add a bias vector. +/// Multiply a matrix with a cooperative vector and add a bias vector. +/// Given a M-row by K-col `matrix`, a K-element column vector `input`, and a M-element vector `bias`, computes `matrix * input + bias`, and +/// returns a M-element vector. /// @param input The input cooperative vector to multiply with the matrix. /// @param inputInterpretation Specifies how to interpret the values in the input vector. /// @param matrix The matrix buffer to multiply with. @@ -26817,6 +26837,13 @@ CoopVec<T, M> coopVecMatMulAddPacked<T : __BuiltinArithmeticType, let M : int, l /// @param transpose Whether to transpose the matrix before multiplication. /// @param matrixStride The stride between matrix rows/columns in bytes. /// @return A new cooperative vector containing the result of the matrix multiplication plus bias. +/// @remarks Depending on target hardware, some combinations of `inputInterpretation`, `matrixInterpretation` and `memoryLayout` may not be supported. +/// For example, CoopVecComponentType.Float32 is not widely supported. Developers should query device properties through the host graphics API to +/// find out which interpretations are supported. +/// +/// Transposing is not supported when `memoryLayout` is `RowMajor` or `ColumnMajor`, and `transpose` must be `false`. +/// Not all component types support transposing. +/// When `memoryLayout` is `InferencingOptimal` or `TrainingOptimal`, `matrixStride` is ignored. [ForceInline] [require(cooperative_vector)] [require(hlsl_coopvec_poc)] @@ -26862,7 +26889,9 @@ ${{{{ if(buffer.isRW) { }}}} -/// Accumulate the outer product of two cooperative vectors into a matrix. +/// Atomically accumulates the outer product of two cooperative vectors into a matrix. Given an M-element vector `a`, and an N-element vector `b`, +/// compute the outer product of `a` and `b`, forming a M-row by N-col matrix. The elements in the matrix is then atomically accumulated +/// to memory location represented by `matrix`. /// @param a The first cooperative vector. /// @param b The second cooperative vector. /// @param matrix The matrix buffer to accumulate the result into. @@ -26870,6 +26899,21 @@ if(buffer.isRW) /// @param matrixStride The stride between matrix rows/columns in bytes. /// @param memoryLayout Specifies the memory layout of the matrix (row-major or column-major). /// @param matrixInterpretation Specifies how to interpret the values in the matrix. +/// @remarks On current hardware, `memoryLayout` must be `TrainingOptimal`. +/// +/// When `memoryLayout` is `RowMajor`, this function is equivalent to: +/// +/// ``` +/// uint8_t* matrixPtr = matrix + matrixOffset; +/// for (int i = 0; i < M; i++) +/// { +/// for (int j = 0; j < N; j++) +/// { +/// let elem = a[i] * b[j]; +/// atomicAdd(matrixPtr + i * matrixStride + j * sizeof(T), elem); +/// } +/// } +/// ``` [require(cooperative_vector)] [require(hlsl_coopvec_poc)] [require(optix_coopvec)] @@ -26959,10 +27003,15 @@ void coopVecOuterProductAccumulate<T : __BuiltinArithmeticType, let M : int, let } } -/// Accumulate the sum of a cooperative vector into a buffer at the specified offset. +/// Atomically accumulates the elements a cooperative vector into a buffer at the specified offset. /// @param v The cooperative vector to sum. /// @param buffer The buffer to accumulate the sum into. /// @param offset Byte offset into the buffer. +/// @remarks This function is equivalent to: +/// ``` +/// for (int i = 0; i < N; i++) +/// atomicAdd(dest[i], v[i]); +/// ``` [require(cooperative_vector)] [require(hlsl_coopvec_poc)] [require(optix_coopvec)] @@ -27015,20 +27064,6 @@ static const struct { for(auto buffer : kStructuredBufferCases_) { }}}} -/// Multiply a cooperative vector with a matrix and return the result. -/// @param input The input cooperative vector to multiply with the matrix. -/// @param inputInterpretation Specifies how to interpret the values in the input vector (e.g. as packed values). -/// @param k The number of columns in the matrix. -/// @param matrix The matrix buffer to multiply with. -/// @param matrixOffset Byte offset into the matrix buffer. -/// @param matrixInterpretation Specifies how to interpret the values in the matrix. -/// @param memoryLayout Specifies the memory layout of the matrix (row-major or column-major). -/// @param transpose Whether to transpose the matrix before multiplication. -/// @param matrixStride The stride between matrix rows/columns in bytes. -/// @return A new cooperative vector containing the result of the matrix multiplication. -/// @remarks Unlike coopVecMatMul, this function supports packed input interpretations where multiple values -/// can be packed into each element of the input vector. The k parameter specifies the actual number of -/// values to use from the packed input. [require(spirv, cooperative_vector)] __generic<T : __BuiltinArithmeticType, let M : int, let PackedK : int, U : __BuiltinArithmeticType,IgnoredBufferElementType> CoopVec<T, M> coopVecMatMulPacked( @@ -27288,6 +27323,185 @@ void coopVecReduceSumAccumulate<T : __BuiltinArithmeticType, let N : int, U, let } } +// Pointer overloads for coopvector operations. + +[require(spirv, cooperative_vector)] +__generic<T : __BuiltinArithmeticType, let M : int, let PackedK : int, U : __BuiltinArithmeticType> +CoopVec<T, M> coopVecMatMulPacked( + CoopVec<U, PackedK> input, + constexpr CoopVecComponentType inputInterpretation, + constexpr int k, + void* matrixPtr, + constexpr CoopVecComponentType matrixInterpretation, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride +) +{ + static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK + , "for non-packed inputInterpretation values k must be equal to the input vector length"); + static_assert(!__isPackedInputInterpretation(inputInterpretation) + || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK + , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor"); + __target_switch + { + case spirv: + let m : int32_t = M; + let matrixInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(matrixInterpretation); + let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation); + let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout); + int operands = 0; // NoneKHR + let zero = 0; + let cvtMatPtr = (Ptr<T[]>)matrixPtr; + if (__isSignedInt<T>()) + { + operands |= 0x08; // MatrixResultSignedComponentsKHR + } + return spirv_asm + { + result:$$CoopVec<T, M> = OpCooperativeVectorMatrixMulNV $input $inputInterpretationSpirv $cvtMatPtr $zero $matrixInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride !operands; + }; + } +} + +// specialized coopVecMatMul for non-packed inputs +[require(spirv, cooperative_vector)] +__generic<T : __BuiltinArithmeticType, let M : int, let K : int, U : __BuiltinArithmeticType> +CoopVec<T, M> coopVecMatMul( + CoopVec<U, K> input, + constexpr CoopVecComponentType inputInterpretation, + void* matrix, + constexpr CoopVecComponentType matrixInterpretation, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride +) +{ + static_assert(!__isPackedInputInterpretation(inputInterpretation) + , "for packed inputInterpretation values please use coopVecMatMulPacked and specify k manually"); + return coopVecMatMulPacked< + T, M, K, U>( + input, + inputInterpretation, + K, + matrix, + matrixInterpretation, + memoryLayout, + transpose, + matrixStride); +} + +[require(spirv, cooperative_vector)] +CoopVec<T, M> coopVecMatMulAddPacked<T : __BuiltinArithmeticType, let M : int, let PackedK : int, U : __BuiltinArithmeticType>( + CoopVec<U, PackedK> input, + constexpr CoopVecComponentType inputInterpretation, + constexpr int k, + void* matrixPtr, + constexpr CoopVecComponentType matrixInterpretation, + void* biasPtr, + constexpr CoopVecComponentType biasInterpretation, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride +) +{ + static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK + , "for non-packed inputInterpretation values k must be equal to the input vector length"); + static_assert(!__isPackedInputInterpretation(inputInterpretation) + || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK + , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor"); + + __target_switch + { + case spirv: + let m : int32_t = M; + let matrixInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(matrixInterpretation); + let biasInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(biasInterpretation); + let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation); + let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout); + let zero : int32_t = 0; + let cvtMatPtr = (Ptr<T[]>)matrixPtr; + let cvtBiasPtr = (Ptr<T[]>)biasPtr; + return spirv_asm + { + result:$$CoopVec<T, M> = OpCooperativeVectorMatrixMulAddNV $input $inputInterpretationSpirv $cvtMatPtr $zero $matrixInterpretationSpirv $cvtBiasPtr $zero $biasInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride; + }; + } +} + +[require(spirv, cooperative_vector)] +CoopVec<T, M> coopVecMatMulAdd<T : __BuiltinArithmeticType, let M : int, let K : int, U : __BuiltinArithmeticType>( + CoopVec<U, K> input, + constexpr CoopVecComponentType inputInterpretation, + void* matrix, + constexpr CoopVecComponentType matrixInterpretation, + void* bias, + constexpr CoopVecComponentType biasInterpretation, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr bool transpose, + constexpr uint matrixStride +) +{ + static_assert(!__isPackedInputInterpretation(inputInterpretation) + , "for packed inputInterpretation values please use coopVecMatMulPacked and specify k manually"); + return coopVecMatMulAddPacked< + T, M, K, U>( + input, + inputInterpretation, + K, + matrix, + matrixInterpretation, + bias, + biasInterpretation, + memoryLayout, + transpose, + matrixStride); +} + +[require(spirv, cooperative_vector_training)] +void coopVecOuterProductAccumulate<T : __BuiltinArithmeticType, let M : int, let N : int>( + CoopVec<T, M> a, + CoopVec<T, N> b, + void* matrixPtr, + constexpr uint matrixStride, + constexpr CoopVecMatrixLayout memoryLayout, + constexpr CoopVecComponentType matrixInterpretation, +) +{ + let zero : int32_t = 0; + __target_switch + { + case spirv: + let matrixInterpretationSpirv : int = __getSpvCoopVecComponentType(matrixInterpretation); + let memoryLayoutSpirv : int = __getSpvCoopVecMatrixLayout(memoryLayout); + let cvtMatrixPtr = (Ptr<T[]>)matrixPtr; + spirv_asm + { + OpCapability CooperativeVectorTrainingNV; + OpCooperativeVectorOuterProductAccumulateNV $cvtMatrixPtr $zero $a $b $memoryLayoutSpirv $matrixInterpretationSpirv $matrixStride; + }; + } +} + +[require(spirv, cooperative_vector_training)] +void coopVecReduceSumAccumulate<T : __BuiltinArithmeticType, let N : int>( + CoopVec<T, N> v, + void* buffer +) +{ + let zero : int32_t = 0; + let bufferPtr = (Ptr<T[]>)(buffer); + __target_switch + { + case spirv: + spirv_asm + { + OpCapability CooperativeVectorTrainingNV; + OpCooperativeVectorReduceSumAccumulateNV $bufferPtr $zero $v; + }; + } +} + //@public: /// Mark a variable as being workgroup uniform. @@ -28126,3 +28340,24 @@ uint packHalf2x16(half2 unpackedValue) { return packHalf2x16(float2(unpackedValue)); } + +[require(spirv)] +void InterlockedAddF16Emulated(half* dest, half value, out half originalValue) +{ + let buf = (half2*)(dest); + uint64_t byteAddress = (uint64_t)dest; + if ((byteAddress & 3) == 0) + { + originalValue = __atomic_add(*buf, half2(value, half(0.0))).x; + } + else + { + originalValue = __atomic_add(*buf, half2(half(0.0), value)).y; + } +} + +[require(spirv)] +void InterlockedAddF16x2(half2* dest, half2 value, out half2 originalValue) +{ + originalValue = __atomic_add(*dest, value); +}
\ No newline at end of file diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 88dea0b7e..56f9d873a 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -187,6 +187,13 @@ class SynthesizedStaticLambdaFuncModifier : public Modifier FIDDLE(...) }; +FIDDLE() +class ExplicitlyDeclaredCapabilityModifier : public Modifier +{ + FIDDLE(...) + FIDDLE() CapabilitySet declaredCapabilityRequirements; +}; + // Marks a synthesized variable as local temporary variable. FIDDLE() class LocalTempVarModifier : public Modifier diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index 506abc1be..6456dbe98 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -419,6 +419,9 @@ bool SemanticsVisitor::createInvokeExprForSynthesizedCtor( if (!structDecl) return false; + if (!structDecl->checkState.isBeingChecked()) + ensureDecl(structDecl, DeclCheckState::AttributesChecked); + HashSet<Type*> isVisit; bool isCStyle = false; if (!_getSynthesizedConstructor( @@ -656,8 +659,8 @@ bool SemanticsVisitor::_readAggregateValueFromInitializerList( auto toMakeArrayFromElementExpr = m_astBuilder->create<MakeArrayFromElementExpr>(); toMakeArrayFromElementExpr->loc = fromInitializerListExpr->loc; toMakeArrayFromElementExpr->type = QualType(toType); - - *outToExpr = toMakeArrayFromElementExpr; + if (outToExpr) + *outToExpr = toMakeArrayFromElementExpr; return true; } for (UInt ee = 0; ee < elementCount; ++ee) @@ -748,8 +751,8 @@ bool SemanticsVisitor::_readAggregateValueFromInitializerList( auto defaultConstructExpr = m_astBuilder->create<DefaultConstructExpr>(); defaultConstructExpr->loc = fromInitializerListExpr->loc; defaultConstructExpr->type = QualType(toType); - - *outToExpr = defaultConstructExpr; + if (outToExpr) + *outToExpr = defaultConstructExpr; return true; } diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 1a70e25d7..0dd859bb2 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -2914,9 +2914,9 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness context->parentDecl->findLastDirectMemberDeclOfName(requirementDeclRef.getName())) { // Remove the `ToBeSynthesizedModifier`. - if (as<ToBeSynthesizedModifier>(existingDecl->modifiers.first)) + if (auto mod = existingDecl->modifiers.findModifier<ToBeSynthesizedModifier>()) { - existingDecl->modifiers.first = existingDecl->modifiers.first->next; + removeModifier(existingDecl, mod); } else { @@ -3133,14 +3133,9 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness addModifier(aggTypeDecl, m_astBuilder->create<SynthesizedModifier>()); - // The visibility of synthesized decl should be the min of the parent decl and the requirement. - if (requirementDeclRef.getDecl()->findModifier<VisibilityModifier>()) - { - auto requirementVisibility = getDeclVisibility(requirementDeclRef.getDecl()); - auto thisVisibility = getDeclVisibility(context->parentDecl); - auto visibility = Math::Min(thisVisibility, requirementVisibility); - addVisibilityModifier(aggTypeDecl, visibility); - } + // The visibility of synthesized decl should be the same of the parent decl. + auto thisVisibility = getDeclVisibility(context->parentDecl); + addVisibilityModifier(aggTypeDecl, thisVisibility); // Synthesize the rest of IDifferential method conformances by recursively checking // conformance on the synthesized decl. @@ -4149,8 +4144,12 @@ bool SemanticsVisitor::doesVarMatchRequirement( return false; } - auto satisfyingVal = - tryConstantFoldDeclRef(satisfyingMemberDeclRef, ConstantFoldingKind::LinkTime, nullptr); + IntVal* satisfyingVal = nullptr; + if (isValidCompileTimeConstantType(satisfyingType)) + { + satisfyingVal = + tryConstantFoldDeclRef(satisfyingMemberDeclRef, ConstantFoldingKind::LinkTime, nullptr); + } if (satisfyingVal) { witnessTable->add(requiredMemberDeclRef.getDecl(), RequirementWitness(satisfyingVal)); @@ -5125,9 +5124,9 @@ void SemanticsVisitor::markOverridingDecl( return; } + memberDecl = maybeGetInner(memberDecl); if (hasDefaultImpl(requiredMemberDeclRef)) { - memberDecl = maybeGetInner(memberDecl); // If the required member has a default implementation, // we need to make sure the member we found is marked as 'override'. if (!memberDecl->hasModifier<OverrideModifier>()) @@ -14290,6 +14289,9 @@ void SemanticsDeclCapabilityVisitor::visitFunctionDeclBase(FunctionDeclBase* fun } else { + auto declaredCapModifier = m_astBuilder->create<ExplicitlyDeclaredCapabilityModifier>(); + declaredCapModifier->declaredCapabilityRequirements = declaredCaps; + addModifier(funcDecl, declaredCapModifier); if (vis == DeclVisibility::Public) { // For public decls, we need to enforce that the function diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 306687bd8..b90081af8 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -682,6 +682,8 @@ Expr* SemanticsVisitor::maybeUseSynthesizedDeclForLookupResult( auto typeDef = m_astBuilder->create<TypeAliasDecl>(); typeDef->nameAndLoc.name = getName("Differential"); typeDef->parentDecl = structDecl; + addVisibilityModifier(structDecl, getDeclVisibility(parent)); + addVisibilityModifier(typeDef, getDeclVisibility(parent)); auto synthDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl)); @@ -714,6 +716,7 @@ Expr* SemanticsVisitor::maybeUseSynthesizedDeclForLookupResult( typeDef->type.type = calcThisType(subType->getDeclRef().getDecl()->getDefaultDeclRef()); + addVisibilityModifier(typeDef, getDeclVisibility(parent)); synthesizedDecl = parent; parent->addDirectMemberDecl(typeDef); @@ -2085,7 +2088,7 @@ IntVal* SemanticsVisitor::tryConstantFoldDeclRef( // to not allow such cases. // // Note that float-to-inst casts for non-`IntVal`s are allowed. - if (!isScalarIntegerType(decl->getType())) + if (!isValidCompileTimeConstantType(decl->getType())) { getSink()->diagnose(declRef, Diagnostics::intValFromNonIntSpecConstEncountered); return nullptr; diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index e6744071b..a360361f7 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -553,6 +553,16 @@ void validateEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink) targetOptionSet.hasOption(CompilerOptionName::Capability) && (targetOptionSet.getIntOption(CompilerOptionName::Capability) != SLANG_CAPABILITY_UNKNOWN); + + if (auto declaredCapsMod = + entryPointFuncDecl->findModifier<ExplicitlyDeclaredCapabilityModifier>()) + { + // If the entry point has an explicitly declared capability, then we + // will merge that with the target capability set before checking if + // there is an implicit upgrade. + targetCaps.nonDestructiveJoin(declaredCapsMod->declaredCapabilityRequirements); + } + // Only attempt to error if a specific profile or capability is requested if ((specificCapabilityRequested || specificProfileRequested) && targetCaps.atLeastOneSetImpliedInOther( |
