diff options
| author | Yong He <yonghe@outlook.com> | 2025-07-21 21:35:44 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-07-22 04:35:44 +0000 |
| commit | 9d47a352960efd71494c7dfa0918debd5b405077 (patch) | |
| tree | f0acf898cb5c4de8a1951ac8010168b119bf94ff | |
| parent | 9adac4069fbcc7ce5bea2c42d19c61eb1dcd7f25 (diff) | |
Fix Conditioanl<T, false> fields with a semantic. (#7855)
* Fix Conditioanl<T, false> fields with a semantic.
* Add unit test.
* Fix test.
| -rw-r--r-- | source/slang/slang-check-impl.h | 19 | ||||
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 20 | ||||
| -rw-r--r-- | source/slang/slang-check-type.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-parameter-binding.cpp | 25 | ||||
| -rw-r--r-- | tests/language-feature/types/conditional-varying.slang | 23 | ||||
| -rw-r--r-- | tests/spirv/spec-constant-generic.slang | 10 | ||||
| -rw-r--r-- | tools/slang-unit-test/unit-test-conditional-vertex-input.cpp | 146 |
7 files changed, 231 insertions, 21 deletions
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 5c6be2665..81dca7312 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1444,7 +1444,18 @@ public: Type* ExtractGenericArgType(Expr* exp); - IntVal* ExtractGenericArgInteger(Expr* exp, Type* genericParamType, DiagnosticSink* sink); + enum class ConstantFoldingKind + { + CompileTime, + LinkTime, + SpecializationConstant + }; + + IntVal* ExtractGenericArgInteger( + Expr* exp, + Type* genericParamType, + ConstantFoldingKind kind, + DiagnosticSink* sink); IntVal* ExtractGenericArgInteger(Expr* exp, Type* genericParamType); Val* ExtractGenericArgVal(Expr* exp); @@ -2184,12 +2195,6 @@ public: Expr* checkPredicateExpr(Expr* expr); - enum class ConstantFoldingKind - { - CompileTime, - LinkTime, - SpecializationConstant - }; Expr* checkExpressionAndExpectIntegerConstant( Expr* expr, IntVal** outIntVal, diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 52b0ef5bc..18fb9798b 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -270,12 +270,31 @@ bool SemanticsVisitor::TryCheckOverloadCandidateVisibility( return true; } +static bool isArrayDecl(Decl* decl) +{ + if (auto magicMod = decl->findModifier<MagicTypeModifier>()) + { + if (magicMod->magicNodeType.getTag() == ASTNodeType::ArrayExpressionType) + return true; + } + return false; +} + bool SemanticsVisitor::TryCheckGenericOverloadCandidateTypes( OverloadResolveContext& context, OverloadCandidate& candidate) { auto genericDeclRef = candidate.item.declRef.as<GenericDecl>(); + // All generic arguments, except array sizes, need to be at least a link-time constant. + // Exception: array sizes can also be a specialization constant. + // + ConstantFoldingKind argFoldingKind = ConstantFoldingKind::LinkTime; + if (isArrayDecl(genericDeclRef.getDecl())) + { + argFoldingKind = ConstantFoldingKind::SpecializationConstant; + } + // Only allow constructing a PartialGenericAppExpr when referencing a callable decl. // Other types of generic decls must be fully specified. bool allowPartialGenericApp = false; @@ -497,6 +516,7 @@ bool SemanticsVisitor::TryCheckGenericOverloadCandidateTypes( val = ExtractGenericArgInteger( arg, getType(m_astBuilder, valParamRef), + argFoldingKind, context.mode == OverloadResolveContext::Mode::JustTrying ? nullptr : getSink()); } diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp index bdf9c829a..82f9596e6 100644 --- a/source/slang/slang-check-type.cpp +++ b/source/slang/slang-check-type.cpp @@ -146,6 +146,7 @@ Type* SemanticsVisitor::ExtractGenericArgType(Expr* exp) IntVal* SemanticsVisitor::ExtractGenericArgInteger( Expr* exp, Type* genericParamType, + ConstantFoldingKind kind, DiagnosticSink* sink) { IntVal* val = CheckIntegerConstantExpression( @@ -153,7 +154,7 @@ IntVal* SemanticsVisitor::ExtractGenericArgInteger( genericParamType ? IntegerConstantExpressionCoercionType::SpecificType : IntegerConstantExpressionCoercionType::AnyInteger, genericParamType, - ConstantFoldingKind::SpecializationConstant, + kind, sink); if (val) return val; @@ -168,7 +169,11 @@ IntVal* SemanticsVisitor::ExtractGenericArgInteger( IntVal* SemanticsVisitor::ExtractGenericArgInteger(Expr* exp, Type* genericParamType) { - return ExtractGenericArgInteger(exp, genericParamType, getSink()); + return ExtractGenericArgInteger( + exp, + genericParamType, + ConstantFoldingKind::LinkTime, + getSink()); } Val* SemanticsVisitor::ExtractGenericArgVal(Expr* exp) diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp index 4ba9e4e03..06d2b1f34 100644 --- a/source/slang/slang-parameter-binding.cpp +++ b/source/slang/slang-parameter-binding.cpp @@ -2216,7 +2216,13 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter( // A matrix is processed as if it was an array of rows else if (auto matrixType = as<MatrixExpressionType>(type)) { - auto rowCount = getIntVal(matrixType->getRowCount()); + auto foldedRowCountVal = + context->getTargetProgram()->getProgram()->tryFoldIntVal(matrixType->getRowCount()); + IntegerLiteralValue rowCount = 0; + if (!foldedRowCountVal) + { + rowCount = getIntVal(foldedRowCountVal); + } return processSimpleEntryPointParameter( context, matrixType, @@ -2228,10 +2234,15 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter( { // Note: Bad Things will happen if we have an array input // without a semantic already being enforced. + UInt elementCount = 0; - auto elementCount = (UInt)getIntVal(arrayType->getElementCount()); - if (arrayType->isUnsized()) - elementCount = 0; + if (!arrayType->isUnsized()) + { + auto intVal = context->getTargetProgram()->getProgram()->tryFoldIntVal( + arrayType->getElementCount()); + if (intVal) + elementCount = (UInt)getIntVal(intVal); + } // We use the first element to derive the layout for the element type auto elementTypeLayout = processEntryPointVaryingParameter( @@ -2456,7 +2467,11 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter( // for (auto fieldTypeResInfo : fieldTypeLayout->resourceInfos) { - SLANG_RELEASE_ASSERT(fieldTypeResInfo.count != 0); + // If the field is a Conditional<T, false> type, then it could have 0 size. + // We should skip this field if it has no use of layout units. + if (fieldTypeResInfo.count == 0) + continue; + auto kind = fieldTypeResInfo.kind; auto structTypeResInfo = structLayout->findOrAddResourceInfo(kind); diff --git a/tests/language-feature/types/conditional-varying.slang b/tests/language-feature/types/conditional-varying.slang new file mode 100644 index 000000000..674ae96e7 --- /dev/null +++ b/tests/language-feature/types/conditional-varying.slang @@ -0,0 +1,23 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv +//TEST:SIMPLE(filecheck=HLSL): -target hlsl -entry fragMain -profile ps_6_0 + +// CHECK: OpEntryPoint +// HLSL: float4 fragMain() : SV_TARGET + +extern static const bool enableConditional = false; + +struct Vertex +{ + Conditional<float3, enableConditional> color : COLOR; +} + +[shader("fragment")] +float4 fragMain(Vertex v) : SV_Target +{ + if (let c = v.color.get()) + { + // This block should not be executed. + return float4(c, 1.0f); + } + return float4(0.0f, 0.0f, 0.0f, 1.0f); +}
\ No newline at end of file diff --git a/tests/spirv/spec-constant-generic.slang b/tests/spirv/spec-constant-generic.slang index 9a9f7006f..1d7e3c1fe 100644 --- a/tests/spirv/spec-constant-generic.slang +++ b/tests/spirv/spec-constant-generic.slang @@ -1,14 +1,13 @@ //TEST:SIMPLE(filecheck=CHECK): -target spirv //TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -output-using-type -emit-spirv-directly -// CHECK: %[[C0:[0-9A-Za-z_]+]] = OpSpecConstant %int 32 +// CHECK: %[[C0:[0-9A-Za-z_]+]] = OpConstant %int 32 // CHECK: %[[C1:[0-9A-Za-z_]+]] = OpSpecConstant %int 2 // CHECK: %[[COP0:[0-9A-Za-z_]+]] = OpSpecConstantOp %int SDiv %[[C0]] %[[C1]] // CHECK: %[[ARR_TYPE:[0-9A-Za-z_]+]] = OpTypeArray %float %[[COP0]] // CHECK: %[[PT_TYPE:[0-9A-Za-z_]+]] = OpTypePointer Function %[[ARR_TYPE]] -[SpecializationConstant] -const int constValue0 = 32; +static const int constValue0 = 32; [SpecializationConstant] const int constValue1 = 2; @@ -33,11 +32,8 @@ struct MyStruct<let N: int> [numthreads(1, 1, 1)] void computeMain() { - // This test checks we can use spec constants for generic arguments, and also - // we can show that the array size is computed correctly. - // The function call shows that the two arrays are the same type. + // This test checks we can use spec constants for array sizes. MyStruct<constValue0> s; - // CHECK: OpVariable %[[PT_TYPE]] Function func(s.buffer); diff --git a/tools/slang-unit-test/unit-test-conditional-vertex-input.cpp b/tools/slang-unit-test/unit-test-conditional-vertex-input.cpp new file mode 100644 index 000000000..d52369152 --- /dev/null +++ b/tools/slang-unit-test/unit-test-conditional-vertex-input.cpp @@ -0,0 +1,146 @@ +// unit-test-unit-test-conditional-vertex-input.cpp + +#include "../../source/core/slang-io.h" +#include "../../source/core/slang-process.h" +#include "slang-com-ptr.h" +#include "slang.h" +#include "unit-test/slang-unit-test.h" + +#include <stdio.h> +#include <stdlib.h> + +using namespace Slang; + +// Test the compilation API for compiling an entrypoint that uses `Conditional<T>` to +// represent conditional vertex attribute input that can be specialized away. + +SLANG_UNIT_TEST(conditionalVertexInput) +{ + const char* userSourceBody = R"( + struct Vertex<bool hasColor> { + float3 pos : POSITION; + Conditional<float3, hasColor> color : COLOR; + float3 normal : NORMAL; + } + + extern static const bool vertexHasColor = true; + + [shader("vertex")] + float4 vertMain(Vertex<vertexHasColor> o) { + if (let color = o.color.get()) + { + // If `vertexHasColor` is true, we can use `color`. + return float4(o.pos + color + o.normal, 1); + } + return float4(o.pos + o.normal, 1); + } + )"; + const char* userSourceBodyNoColor = R"(export static const bool vertexHasColor = false;)"; + + ComPtr<slang::IGlobalSession> globalSession; + SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK); + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_GLSL; + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + sessionDesc.targets = &targetDesc; + ComPtr<slang::ISession> session; + SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); + + ComPtr<slang::IBlob> diagnosticBlob; + auto module = session->loadModuleFromSourceString( + "m", + "m.slang", + userSourceBody, + diagnosticBlob.writeRef()); + SLANG_CHECK(module != nullptr); + + ComPtr<slang::IEntryPoint> entryPoint; + module->findAndCheckEntryPoint( + "vertMain", + SLANG_STAGE_VERTEX, + entryPoint.writeRef(), + diagnosticBlob.writeRef()); + + // Check the program with `vertexHasColor = true`. + { + slang::IComponentType* componentTypes[2] = {module, entryPoint.get()}; + ComPtr<slang::IComponentType> composedProgram; + session->createCompositeComponentType( + componentTypes, + 2, + composedProgram.writeRef(), + diagnosticBlob.writeRef()); + + ComPtr<slang::IComponentType> linkedProgram; + composedProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()); + + auto paramLayout = linkedProgram->getLayout() + ->getEntryPointByIndex(0) + ->getParameterByIndex(0) + ->getTypeLayout(); + + // Total number of varying inputs should be 3. (pos, color and normal) + SLANG_CHECK(paramLayout->getSize(slang::ParameterCategory::VaryingInput) == 3); + + // Offset of `normal` should be 2. + SLANG_CHECK( + paramLayout + ->getFieldByIndex(2) // `o.normal` + ->getOffset(slang::ParameterCategory::VaryingInput) == 2); + ComPtr<slang::IBlob> code; + SLANG_CHECK( + linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef()) == + SLANG_OK); + auto codeStr = Slang::UnownedStringSlice((const char*)code->getBufferPointer()); + SLANG_CHECK(codeStr.indexOf(toSlice("layout(location = 0)")) != -1); + SLANG_CHECK(codeStr.indexOf(toSlice("layout(location = 1)")) != -1); + SLANG_CHECK(codeStr.indexOf(toSlice("layout(location = 2)")) != -1); + } + + // Check the program with `vertexHashColor = false`. + { + auto configModule = session->loadModuleFromSourceString( + "config", + "config.slang", + userSourceBodyNoColor, + diagnosticBlob.writeRef()); + SLANG_CHECK(module != nullptr); + + slang::IComponentType* componentTypes[3] = {module, entryPoint.get(), configModule}; + ComPtr<slang::IComponentType> composedProgram; + session->createCompositeComponentType( + componentTypes, + 3, + composedProgram.writeRef(), + diagnosticBlob.writeRef()); + + ComPtr<slang::IComponentType> linkedProgram; + composedProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef()); + + auto paramLayout = linkedProgram->getLayout() + ->getEntryPointByIndex(0) + ->getParameterByIndex(0) + ->getTypeLayout(); + + // Total number of varying inputs should be 2. (pos and normal) + SLANG_CHECK(paramLayout->getSize(slang::ParameterCategory::VaryingInput) == 2); + + // Offset of `normal` should be 1. + SLANG_CHECK( + paramLayout + ->getFieldByIndex(2) // `o.normal` + ->getOffset(slang::ParameterCategory::VaryingInput) == 1); + ComPtr<slang::IBlob> code; + SLANG_CHECK( + linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef()) == + SLANG_OK); + + auto codeStr = Slang::UnownedStringSlice((const char*)code->getBufferPointer()); + + SLANG_CHECK(codeStr.indexOf(toSlice("layout(location = 0)")) != -1); + SLANG_CHECK(codeStr.indexOf(toSlice("layout(location = 1)")) != -1); + // Resulting code should not contain `layout(location = 1)` since `color` is not used. + SLANG_CHECK(codeStr.indexOf(toSlice("layout(location = 2)")) == -1); + } +} |
