summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-06-30 14:32:50 -0700
committerGitHub <noreply@github.com>2025-06-30 21:32:50 +0000
commitf28f67d988158d6c46f7ffe967152f98d32a37b2 (patch)
tree2aa620986a87ec69cf1f210c714312e42b62ac9e /source
parenta55ff722cae338a8fcf5402858c47cf0650a8e5e (diff)
Add MLP training examples. (#7550)
* Add MLP training examples. * Formatting fix. * Fix. * Improve documentation on coopvector. * Improve doc. * Update doc. * Fix typo. * Cleanup shader. * Cleanup. * Fix test. * Fix type check recursion. * Fix. * Fix. * Fix override check.
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang4
-rw-r--r--source/slang/hlsl.meta.slang277
-rw-r--r--source/slang/slang-ast-modifier.h7
-rw-r--r--source/slang/slang-check-conversion.cpp11
-rw-r--r--source/slang/slang-check-decl.cpp28
-rw-r--r--source/slang/slang-check-expr.cpp5
-rw-r--r--source/slang/slang-check-shader.cpp10
7 files changed, 301 insertions, 41 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index deaeae439..2e56c1082 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -1809,12 +1809,12 @@ extension Ptr<void>
__init(NativeString nativeStr) { this = nativeStr.getBuffer(); }
__generic<T, let addrSpace : uint64_t>
- __intrinsic_op(0)
+ __intrinsic_op($(kIROp_BitCast))
__implicit_conversion($(kConversionCost_PtrToVoidPtr))
__init(Ptr<T, addrSpace> ptr);
__generic<T>
- __intrinsic_op(0)
+ __intrinsic_op($(kIROp_BitCast))
__implicit_conversion($(kConversionCost_PtrToVoidPtr))
__init(NativeRef<T> ptr);
}
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 32d7ea824..38f274984 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -26610,16 +26610,24 @@ CoopVec<T, M> coopVecMatMulPacked(
}
}
-/// Multiply a cooperative vector with a matrix.
-/// @param input The input cooperative vector to multiply with the matrix.
+/// Multiply a matrix with a cooperative vector. Given a M-row by K-col `matrix`, and a K-element column vector `input`, computes `matrix * input`, and
+/// returns a M-element vector.
+/// @param input The K-element input cooperative vector to multiply with the matrix.
/// @param inputInterpretation Specifies how to interpret the values in the input vector (e.g. as 8-bit integers, 16-bit floats, etc).
-/// @param matrix The matrix to multiply with the input vector.
+/// @param matrix The M-by-K matrix to multiply with the input vector.
/// @param matrixOffset Byte offset into the matrix buffer.
/// @param matrixInterpretation Specifies how to interpret the values in the matrix (e.g. as 8-bit integers, 16-bit floats, etc).
/// @param memoryLayout Specifies the memory layout of the matrix (row-major or column-major).
/// @param transpose Whether to transpose the matrix before multiplication.
/// @param matrixStride The stride in bytes between rows/columns of the matrix.
/// @return A new cooperative vector containing the result of the matrix multiplication.
+/// @remarks Depending on target hardware, some combinations of `inputInterpretation`, `matrixInterpretation` and `memoryLayout` may not be supported.
+/// For example, CoopVecComponentType.Float32 is not widely supported. Developers should query device properties through the host graphics API to
+/// find out which interpretations are supported.
+///
+/// Transposing is not supported when `memoryLayout` is `RowMajor` or `ColumnMajor`, and `transpose` must be `false`.
+/// Not all component types support transposing.
+/// When `memoryLayout` is `InferencingOptimal` or `TrainingOptimal`, `matrixStride` is ignored.
[ForceInline]
[require(cooperative_vector)]
[require(hlsl_coopvec_poc)]
@@ -26650,7 +26658,9 @@ CoopVec<T, M> coopVecMatMul(
matrixStride);
}
-/// Multiply a cooperative vector with a matrix and add a bias vector.
+/// Multiply a matrix with a cooperative vector and add a bias vector to the result.
+/// Given a M-row by K-col `matrix`, a K-element column vector `input`, and a M-element vector `bias`, computes `matrix * input + bias`, and
+/// returns a M-element vector.
/// @param input The input cooperative vector to multiply with the matrix.
/// @param inputInterpretation Specifies how to interpret the values in the input vector (e.g. as packed values).
/// @param k The number of columns in the matrix.
@@ -26667,6 +26677,14 @@ CoopVec<T, M> coopVecMatMul(
/// @remarks Unlike coopVecMatMulAdd, this function supports packed input interpretations where multiple values
/// can be packed into each element of the input vector. The k parameter specifies the actual number of
/// values to use from the packed input.
+///
+/// Depending on target hardware, some combinations of `inputInterpretation`, `matrixInterpretation` and `memoryLayout` may not be supported.
+/// For example, CoopVecComponentType.Float32 is not widely supported. Developers should query device properties through the host graphics API to
+/// find out which interpretations are supported.
+///
+/// Transposing is not supported when `memoryLayout` is `RowMajor` or `ColumnMajor`, and `transpose` must be `false`.
+/// Not all component types support transposing.
+/// When `memoryLayout` is `InferencingOptimal` or `TrainingOptimal`, `matrixStride` is ignored.
[ForceInline]
[require(cooperative_vector)]
[require(hlsl_coopvec_poc)]
@@ -26804,7 +26822,9 @@ CoopVec<T, M> coopVecMatMulAddPacked<T : __BuiltinArithmeticType, let M : int, l
}
}
-/// Multiply a cooperative vector with a matrix and add a bias vector.
+/// Multiply a matrix with a cooperative vector and add a bias vector.
+/// Given a M-row by K-col `matrix`, a K-element column vector `input`, and a M-element vector `bias`, computes `matrix * input + bias`, and
+/// returns a M-element vector.
/// @param input The input cooperative vector to multiply with the matrix.
/// @param inputInterpretation Specifies how to interpret the values in the input vector.
/// @param matrix The matrix buffer to multiply with.
@@ -26817,6 +26837,13 @@ CoopVec<T, M> coopVecMatMulAddPacked<T : __BuiltinArithmeticType, let M : int, l
/// @param transpose Whether to transpose the matrix before multiplication.
/// @param matrixStride The stride between matrix rows/columns in bytes.
/// @return A new cooperative vector containing the result of the matrix multiplication plus bias.
+/// @remarks Depending on target hardware, some combinations of `inputInterpretation`, `matrixInterpretation` and `memoryLayout` may not be supported.
+/// For example, CoopVecComponentType.Float32 is not widely supported. Developers should query device properties through the host graphics API to
+/// find out which interpretations are supported.
+///
+/// Transposing is not supported when `memoryLayout` is `RowMajor` or `ColumnMajor`, and `transpose` must be `false`.
+/// Not all component types support transposing.
+/// When `memoryLayout` is `InferencingOptimal` or `TrainingOptimal`, `matrixStride` is ignored.
[ForceInline]
[require(cooperative_vector)]
[require(hlsl_coopvec_poc)]
@@ -26862,7 +26889,9 @@ ${{{{
if(buffer.isRW)
{
}}}}
-/// Accumulate the outer product of two cooperative vectors into a matrix.
+/// Atomically accumulates the outer product of two cooperative vectors into a matrix. Given an M-element vector `a`, and an N-element vector `b`,
+/// compute the outer product of `a` and `b`, forming a M-row by N-col matrix. The elements in the matrix is then atomically accumulated
+/// to memory location represented by `matrix`.
/// @param a The first cooperative vector.
/// @param b The second cooperative vector.
/// @param matrix The matrix buffer to accumulate the result into.
@@ -26870,6 +26899,21 @@ if(buffer.isRW)
/// @param matrixStride The stride between matrix rows/columns in bytes.
/// @param memoryLayout Specifies the memory layout of the matrix (row-major or column-major).
/// @param matrixInterpretation Specifies how to interpret the values in the matrix.
+/// @remarks On current hardware, `memoryLayout` must be `TrainingOptimal`.
+///
+/// When `memoryLayout` is `RowMajor`, this function is equivalent to:
+///
+/// ```
+/// uint8_t* matrixPtr = matrix + matrixOffset;
+/// for (int i = 0; i < M; i++)
+/// {
+/// for (int j = 0; j < N; j++)
+/// {
+/// let elem = a[i] * b[j];
+/// atomicAdd(matrixPtr + i * matrixStride + j * sizeof(T), elem);
+/// }
+/// }
+/// ```
[require(cooperative_vector)]
[require(hlsl_coopvec_poc)]
[require(optix_coopvec)]
@@ -26959,10 +27003,15 @@ void coopVecOuterProductAccumulate<T : __BuiltinArithmeticType, let M : int, let
}
}
-/// Accumulate the sum of a cooperative vector into a buffer at the specified offset.
+/// Atomically accumulates the elements a cooperative vector into a buffer at the specified offset.
/// @param v The cooperative vector to sum.
/// @param buffer The buffer to accumulate the sum into.
/// @param offset Byte offset into the buffer.
+/// @remarks This function is equivalent to:
+/// ```
+/// for (int i = 0; i < N; i++)
+/// atomicAdd(dest[i], v[i]);
+/// ```
[require(cooperative_vector)]
[require(hlsl_coopvec_poc)]
[require(optix_coopvec)]
@@ -27015,20 +27064,6 @@ static const struct {
for(auto buffer : kStructuredBufferCases_) {
}}}}
-/// Multiply a cooperative vector with a matrix and return the result.
-/// @param input The input cooperative vector to multiply with the matrix.
-/// @param inputInterpretation Specifies how to interpret the values in the input vector (e.g. as packed values).
-/// @param k The number of columns in the matrix.
-/// @param matrix The matrix buffer to multiply with.
-/// @param matrixOffset Byte offset into the matrix buffer.
-/// @param matrixInterpretation Specifies how to interpret the values in the matrix.
-/// @param memoryLayout Specifies the memory layout of the matrix (row-major or column-major).
-/// @param transpose Whether to transpose the matrix before multiplication.
-/// @param matrixStride The stride between matrix rows/columns in bytes.
-/// @return A new cooperative vector containing the result of the matrix multiplication.
-/// @remarks Unlike coopVecMatMul, this function supports packed input interpretations where multiple values
-/// can be packed into each element of the input vector. The k parameter specifies the actual number of
-/// values to use from the packed input.
[require(spirv, cooperative_vector)]
__generic<T : __BuiltinArithmeticType, let M : int, let PackedK : int, U : __BuiltinArithmeticType,IgnoredBufferElementType>
CoopVec<T, M> coopVecMatMulPacked(
@@ -27288,6 +27323,185 @@ void coopVecReduceSumAccumulate<T : __BuiltinArithmeticType, let N : int, U, let
}
}
+// Pointer overloads for coopvector operations.
+
+[require(spirv, cooperative_vector)]
+__generic<T : __BuiltinArithmeticType, let M : int, let PackedK : int, U : __BuiltinArithmeticType>
+CoopVec<T, M> coopVecMatMulPacked(
+ CoopVec<U, PackedK> input,
+ constexpr CoopVecComponentType inputInterpretation,
+ constexpr int k,
+ void* matrixPtr,
+ constexpr CoopVecComponentType matrixInterpretation,
+ constexpr CoopVecMatrixLayout memoryLayout,
+ constexpr bool transpose,
+ constexpr uint matrixStride
+)
+{
+ static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK
+ , "for non-packed inputInterpretation values k must be equal to the input vector length");
+ static_assert(!__isPackedInputInterpretation(inputInterpretation)
+ || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK
+ , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor");
+ __target_switch
+ {
+ case spirv:
+ let m : int32_t = M;
+ let matrixInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(matrixInterpretation);
+ let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation);
+ let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout);
+ int operands = 0; // NoneKHR
+ let zero = 0;
+ let cvtMatPtr = (Ptr<T[]>)matrixPtr;
+ if (__isSignedInt<T>())
+ {
+ operands |= 0x08; // MatrixResultSignedComponentsKHR
+ }
+ return spirv_asm
+ {
+ result:$$CoopVec<T, M> = OpCooperativeVectorMatrixMulNV $input $inputInterpretationSpirv $cvtMatPtr $zero $matrixInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride !operands;
+ };
+ }
+}
+
+// specialized coopVecMatMul for non-packed inputs
+[require(spirv, cooperative_vector)]
+__generic<T : __BuiltinArithmeticType, let M : int, let K : int, U : __BuiltinArithmeticType>
+CoopVec<T, M> coopVecMatMul(
+ CoopVec<U, K> input,
+ constexpr CoopVecComponentType inputInterpretation,
+ void* matrix,
+ constexpr CoopVecComponentType matrixInterpretation,
+ constexpr CoopVecMatrixLayout memoryLayout,
+ constexpr bool transpose,
+ constexpr uint matrixStride
+)
+{
+ static_assert(!__isPackedInputInterpretation(inputInterpretation)
+ , "for packed inputInterpretation values please use coopVecMatMulPacked and specify k manually");
+ return coopVecMatMulPacked<
+ T, M, K, U>(
+ input,
+ inputInterpretation,
+ K,
+ matrix,
+ matrixInterpretation,
+ memoryLayout,
+ transpose,
+ matrixStride);
+}
+
+[require(spirv, cooperative_vector)]
+CoopVec<T, M> coopVecMatMulAddPacked<T : __BuiltinArithmeticType, let M : int, let PackedK : int, U : __BuiltinArithmeticType>(
+ CoopVec<U, PackedK> input,
+ constexpr CoopVecComponentType inputInterpretation,
+ constexpr int k,
+ void* matrixPtr,
+ constexpr CoopVecComponentType matrixInterpretation,
+ void* biasPtr,
+ constexpr CoopVecComponentType biasInterpretation,
+ constexpr CoopVecMatrixLayout memoryLayout,
+ constexpr bool transpose,
+ constexpr uint matrixStride
+)
+{
+ static_assert(__isPackedInputInterpretation(inputInterpretation) || k == PackedK
+ , "for non-packed inputInterpretation values k must be equal to the input vector length");
+ static_assert(!__isPackedInputInterpretation(inputInterpretation)
+ || k <= __inputInterpretationPackingFactor(inputInterpretation)*PackedK
+ , "for packed inputInterpretation values k must be less than or equal to the input vector length times the packing factor");
+
+ __target_switch
+ {
+ case spirv:
+ let m : int32_t = M;
+ let matrixInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(matrixInterpretation);
+ let biasInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(biasInterpretation);
+ let inputInterpretationSpirv : int32_t = __getSpvCoopVecComponentType(inputInterpretation);
+ let memoryLayoutSpirv : int32_t = __getSpvCoopVecMatrixLayout(memoryLayout);
+ let zero : int32_t = 0;
+ let cvtMatPtr = (Ptr<T[]>)matrixPtr;
+ let cvtBiasPtr = (Ptr<T[]>)biasPtr;
+ return spirv_asm
+ {
+ result:$$CoopVec<T, M> = OpCooperativeVectorMatrixMulAddNV $input $inputInterpretationSpirv $cvtMatPtr $zero $matrixInterpretationSpirv $cvtBiasPtr $zero $biasInterpretationSpirv $m $k $memoryLayoutSpirv $transpose $matrixStride;
+ };
+ }
+}
+
+[require(spirv, cooperative_vector)]
+CoopVec<T, M> coopVecMatMulAdd<T : __BuiltinArithmeticType, let M : int, let K : int, U : __BuiltinArithmeticType>(
+ CoopVec<U, K> input,
+ constexpr CoopVecComponentType inputInterpretation,
+ void* matrix,
+ constexpr CoopVecComponentType matrixInterpretation,
+ void* bias,
+ constexpr CoopVecComponentType biasInterpretation,
+ constexpr CoopVecMatrixLayout memoryLayout,
+ constexpr bool transpose,
+ constexpr uint matrixStride
+)
+{
+ static_assert(!__isPackedInputInterpretation(inputInterpretation)
+ , "for packed inputInterpretation values please use coopVecMatMulPacked and specify k manually");
+ return coopVecMatMulAddPacked<
+ T, M, K, U>(
+ input,
+ inputInterpretation,
+ K,
+ matrix,
+ matrixInterpretation,
+ bias,
+ biasInterpretation,
+ memoryLayout,
+ transpose,
+ matrixStride);
+}
+
+[require(spirv, cooperative_vector_training)]
+void coopVecOuterProductAccumulate<T : __BuiltinArithmeticType, let M : int, let N : int>(
+ CoopVec<T, M> a,
+ CoopVec<T, N> b,
+ void* matrixPtr,
+ constexpr uint matrixStride,
+ constexpr CoopVecMatrixLayout memoryLayout,
+ constexpr CoopVecComponentType matrixInterpretation,
+)
+{
+ let zero : int32_t = 0;
+ __target_switch
+ {
+ case spirv:
+ let matrixInterpretationSpirv : int = __getSpvCoopVecComponentType(matrixInterpretation);
+ let memoryLayoutSpirv : int = __getSpvCoopVecMatrixLayout(memoryLayout);
+ let cvtMatrixPtr = (Ptr<T[]>)matrixPtr;
+ spirv_asm
+ {
+ OpCapability CooperativeVectorTrainingNV;
+ OpCooperativeVectorOuterProductAccumulateNV $cvtMatrixPtr $zero $a $b $memoryLayoutSpirv $matrixInterpretationSpirv $matrixStride;
+ };
+ }
+}
+
+[require(spirv, cooperative_vector_training)]
+void coopVecReduceSumAccumulate<T : __BuiltinArithmeticType, let N : int>(
+ CoopVec<T, N> v,
+ void* buffer
+)
+{
+ let zero : int32_t = 0;
+ let bufferPtr = (Ptr<T[]>)(buffer);
+ __target_switch
+ {
+ case spirv:
+ spirv_asm
+ {
+ OpCapability CooperativeVectorTrainingNV;
+ OpCooperativeVectorReduceSumAccumulateNV $bufferPtr $zero $v;
+ };
+ }
+}
+
//@public:
/// Mark a variable as being workgroup uniform.
@@ -28126,3 +28340,24 @@ uint packHalf2x16(half2 unpackedValue)
{
return packHalf2x16(float2(unpackedValue));
}
+
+[require(spirv)]
+void InterlockedAddF16Emulated(half* dest, half value, out half originalValue)
+{
+ let buf = (half2*)(dest);
+ uint64_t byteAddress = (uint64_t)dest;
+ if ((byteAddress & 3) == 0)
+ {
+ originalValue = __atomic_add(*buf, half2(value, half(0.0))).x;
+ }
+ else
+ {
+ originalValue = __atomic_add(*buf, half2(half(0.0), value)).y;
+ }
+}
+
+[require(spirv)]
+void InterlockedAddF16x2(half2* dest, half2 value, out half2 originalValue)
+{
+ originalValue = __atomic_add(*dest, value);
+} \ No newline at end of file
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 88dea0b7e..56f9d873a 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -187,6 +187,13 @@ class SynthesizedStaticLambdaFuncModifier : public Modifier
FIDDLE(...)
};
+FIDDLE()
+class ExplicitlyDeclaredCapabilityModifier : public Modifier
+{
+ FIDDLE(...)
+ FIDDLE() CapabilitySet declaredCapabilityRequirements;
+};
+
// Marks a synthesized variable as local temporary variable.
FIDDLE()
class LocalTempVarModifier : public Modifier
diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp
index 506abc1be..6456dbe98 100644
--- a/source/slang/slang-check-conversion.cpp
+++ b/source/slang/slang-check-conversion.cpp
@@ -419,6 +419,9 @@ bool SemanticsVisitor::createInvokeExprForSynthesizedCtor(
if (!structDecl)
return false;
+ if (!structDecl->checkState.isBeingChecked())
+ ensureDecl(structDecl, DeclCheckState::AttributesChecked);
+
HashSet<Type*> isVisit;
bool isCStyle = false;
if (!_getSynthesizedConstructor(
@@ -656,8 +659,8 @@ bool SemanticsVisitor::_readAggregateValueFromInitializerList(
auto toMakeArrayFromElementExpr = m_astBuilder->create<MakeArrayFromElementExpr>();
toMakeArrayFromElementExpr->loc = fromInitializerListExpr->loc;
toMakeArrayFromElementExpr->type = QualType(toType);
-
- *outToExpr = toMakeArrayFromElementExpr;
+ if (outToExpr)
+ *outToExpr = toMakeArrayFromElementExpr;
return true;
}
for (UInt ee = 0; ee < elementCount; ++ee)
@@ -748,8 +751,8 @@ bool SemanticsVisitor::_readAggregateValueFromInitializerList(
auto defaultConstructExpr = m_astBuilder->create<DefaultConstructExpr>();
defaultConstructExpr->loc = fromInitializerListExpr->loc;
defaultConstructExpr->type = QualType(toType);
-
- *outToExpr = defaultConstructExpr;
+ if (outToExpr)
+ *outToExpr = defaultConstructExpr;
return true;
}
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 1a70e25d7..0dd859bb2 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -2914,9 +2914,9 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness
context->parentDecl->findLastDirectMemberDeclOfName(requirementDeclRef.getName()))
{
// Remove the `ToBeSynthesizedModifier`.
- if (as<ToBeSynthesizedModifier>(existingDecl->modifiers.first))
+ if (auto mod = existingDecl->modifiers.findModifier<ToBeSynthesizedModifier>())
{
- existingDecl->modifiers.first = existingDecl->modifiers.first->next;
+ removeModifier(existingDecl, mod);
}
else
{
@@ -3133,14 +3133,9 @@ bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness
addModifier(aggTypeDecl, m_astBuilder->create<SynthesizedModifier>());
- // The visibility of synthesized decl should be the min of the parent decl and the requirement.
- if (requirementDeclRef.getDecl()->findModifier<VisibilityModifier>())
- {
- auto requirementVisibility = getDeclVisibility(requirementDeclRef.getDecl());
- auto thisVisibility = getDeclVisibility(context->parentDecl);
- auto visibility = Math::Min(thisVisibility, requirementVisibility);
- addVisibilityModifier(aggTypeDecl, visibility);
- }
+ // The visibility of synthesized decl should be the same of the parent decl.
+ auto thisVisibility = getDeclVisibility(context->parentDecl);
+ addVisibilityModifier(aggTypeDecl, thisVisibility);
// Synthesize the rest of IDifferential method conformances by recursively checking
// conformance on the synthesized decl.
@@ -4149,8 +4144,12 @@ bool SemanticsVisitor::doesVarMatchRequirement(
return false;
}
- auto satisfyingVal =
- tryConstantFoldDeclRef(satisfyingMemberDeclRef, ConstantFoldingKind::LinkTime, nullptr);
+ IntVal* satisfyingVal = nullptr;
+ if (isValidCompileTimeConstantType(satisfyingType))
+ {
+ satisfyingVal =
+ tryConstantFoldDeclRef(satisfyingMemberDeclRef, ConstantFoldingKind::LinkTime, nullptr);
+ }
if (satisfyingVal)
{
witnessTable->add(requiredMemberDeclRef.getDecl(), RequirementWitness(satisfyingVal));
@@ -5125,9 +5124,9 @@ void SemanticsVisitor::markOverridingDecl(
return;
}
+ memberDecl = maybeGetInner(memberDecl);
if (hasDefaultImpl(requiredMemberDeclRef))
{
- memberDecl = maybeGetInner(memberDecl);
// If the required member has a default implementation,
// we need to make sure the member we found is marked as 'override'.
if (!memberDecl->hasModifier<OverrideModifier>())
@@ -14290,6 +14289,9 @@ void SemanticsDeclCapabilityVisitor::visitFunctionDeclBase(FunctionDeclBase* fun
}
else
{
+ auto declaredCapModifier = m_astBuilder->create<ExplicitlyDeclaredCapabilityModifier>();
+ declaredCapModifier->declaredCapabilityRequirements = declaredCaps;
+ addModifier(funcDecl, declaredCapModifier);
if (vis == DeclVisibility::Public)
{
// For public decls, we need to enforce that the function
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 306687bd8..b90081af8 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -682,6 +682,8 @@ Expr* SemanticsVisitor::maybeUseSynthesizedDeclForLookupResult(
auto typeDef = m_astBuilder->create<TypeAliasDecl>();
typeDef->nameAndLoc.name = getName("Differential");
typeDef->parentDecl = structDecl;
+ addVisibilityModifier(structDecl, getDeclVisibility(parent));
+ addVisibilityModifier(typeDef, getDeclVisibility(parent));
auto synthDeclRef =
createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl));
@@ -714,6 +716,7 @@ Expr* SemanticsVisitor::maybeUseSynthesizedDeclForLookupResult(
typeDef->type.type =
calcThisType(subType->getDeclRef().getDecl()->getDefaultDeclRef());
+ addVisibilityModifier(typeDef, getDeclVisibility(parent));
synthesizedDecl = parent;
parent->addDirectMemberDecl(typeDef);
@@ -2085,7 +2088,7 @@ IntVal* SemanticsVisitor::tryConstantFoldDeclRef(
// to not allow such cases.
//
// Note that float-to-inst casts for non-`IntVal`s are allowed.
- if (!isScalarIntegerType(decl->getType()))
+ if (!isValidCompileTimeConstantType(decl->getType()))
{
getSink()->diagnose(declRef, Diagnostics::intValFromNonIntSpecConstEncountered);
return nullptr;
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp
index e6744071b..a360361f7 100644
--- a/source/slang/slang-check-shader.cpp
+++ b/source/slang/slang-check-shader.cpp
@@ -553,6 +553,16 @@ void validateEntryPoint(EntryPoint* entryPoint, DiagnosticSink* sink)
targetOptionSet.hasOption(CompilerOptionName::Capability) &&
(targetOptionSet.getIntOption(CompilerOptionName::Capability) !=
SLANG_CAPABILITY_UNKNOWN);
+
+ if (auto declaredCapsMod =
+ entryPointFuncDecl->findModifier<ExplicitlyDeclaredCapabilityModifier>())
+ {
+ // If the entry point has an explicitly declared capability, then we
+ // will merge that with the target capability set before checking if
+ // there is an implicit upgrade.
+ targetCaps.nonDestructiveJoin(declaredCapsMod->declaredCapabilityRequirements);
+ }
+
// Only attempt to error if a specific profile or capability is requested
if ((specificCapabilityRequested || specificProfileRequested) &&
targetCaps.atLeastOneSetImpliedInOther(