diff options
| -rw-r--r-- | source/slang/slang-check-shader.cpp | 76 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 3 | ||||
| -rw-r--r-- | tests/bugs/invalid-entrypoint-param.slang | 17 | ||||
| -rw-r--r-- | tests/language-feature/shader-params/entry-point-uniform-params-implicit.slang | 42 | ||||
| -rw-r--r-- | tests/reflection/attribute.slang.expected | 6 |
5 files changed, 104 insertions, 40 deletions
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index a802906a7..99205e522 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -333,7 +333,17 @@ namespace Slang bool isBuiltinParameterType(Type* type) { - return as<BuiltinType>(type) != nullptr; + if (!as<BuiltinType>(type)) + return false; + if (as<BasicExpressionType>(type)) + return false; + if (as<VectorExpressionType>(type)) + return false; + if (as<MatrixExpressionType>(type)) + return false; + if (auto arrayType = as<ArrayExpressionType>(type)) + return isBuiltinParameterType(arrayType->getElementType()); + return true; } bool doStructFieldsHaveSemanticImpl(Type* type, HashSet<Type*>& seenTypes) @@ -345,18 +355,20 @@ namespace Slang if (!structDecl) return false; seenTypes.add(type); + bool hasFields = false; for (auto field : structDecl->getFields()) { + hasFields = true; if (!field->findModifier<HLSLSemantic>()) { - if (!seenTypes.contains(type)) + if (!seenTypes.contains(field->getType())) { if (!doStructFieldsHaveSemanticImpl(field->getType(), seenTypes)) return false; } } } - return true; + return hasFields; } bool doStructFieldsHaveSemantic(Type* type) @@ -488,30 +500,58 @@ namespace Slang } } + bool canHaveVaryingInput = false; + switch (stage) + { + case Stage::Vertex: + case Stage::Fragment: + case Stage::Miss: + case Stage::AnyHit: + case Stage::ClosestHit: + case Stage::Callable: + case Stage::Geometry: + case Stage::Mesh: + case Stage::Hull: + case Stage::Domain: + canHaveVaryingInput = true; + break; + default: + break; + } + for (const auto& param : entryPointFuncDecl->getParameters()) { if (isUniformParameterType(param->getType())) { // Automatically add `uniform` modifier to entry point parameters. if (!param->hasModifier<HLSLUniformModifier>()) - addModifier(param, getCurrentASTBuilder()->create<HLSLUniformModifier>()); - } - else if (isBuiltinParameterType(param->getType())) - { - } - else - { - // For all non-uniform parameters of a general type, we require the parameter be associated with - // a system value semantic. - if (!param->hasModifier<HLSLUniformModifier>()) { - if (!param->findModifier<HLSLSemantic>()) - { - if (!doStructFieldsHaveSemantic(param->getType())) - sink->diagnose(param, Diagnostics::nonUniformEntryPointParameterMustHaveSemantic, param->getName()); - } + addModifier(param, getCurrentASTBuilder()->create<HLSLUniformModifier>()); + continue; } } + + if (canHaveVaryingInput) + continue; + + // If the stage doesn't allow varying input/output, + // we require the parameter to be associated with a system value semantic. + if (param->hasModifier<HLSLUniformModifier>()) + continue; + if (param->findModifier<HLSLSemantic>()) + continue; + + bool isBuiltinType = isBuiltinParameterType(param->getType()); + if (isBuiltinType) + continue; + + if (doStructFieldsHaveSemantic(param->getType())) + continue; + + // The user is defining a parameter with no 'uniform' modifier for a stage that doesn't support + // varying input/output. We will automatically convert it to a 'uniform' parameter, and diagnose a warning. + addModifier(param, getCurrentASTBuilder()->create<HLSLUniformModifier>()); + sink->diagnose(param, Diagnostics::nonUniformEntryPointParameterTreatedAsUniform, param->getName()); } for (auto target : linkage->targets) diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index cb3e39dd4..487f264d5 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -649,7 +649,8 @@ DIAGNOSTIC(38033, Error, cannotUseNoDiffInNonDifferentiableFunc, "cannot use 'no DIAGNOSTIC(38034, Error, cannotUseConstRefOnDifferentiableParameter, "cannot use '__constref' on a differentiable parameter.") DIAGNOSTIC(38034, Error, cannotUseConstRefOnDifferentiableMemberMethod, "cannot use '[constref]' on a differentiable member method of a differentiable type.") -DIAGNOSTIC(38040, Error, nonUniformEntryPointParameterMustHaveSemantic, "non-uniform parameter '$0' must have a system-value semantic.") +DIAGNOSTIC(38040, Warning, nonUniformEntryPointParameterTreatedAsUniform, "parameter '$0' is treated as 'uniform' because it does not have a system-value semantic.") + DIAGNOSTIC(38200, Error, recursiveModuleImport, "module `$0` recursively imports itself") DIAGNOSTIC(39999, Error, errorInImportedModule, "import of module '$0' failed because of a compilation error") diff --git a/tests/bugs/invalid-entrypoint-param.slang b/tests/bugs/invalid-entrypoint-param.slang deleted file mode 100644 index bafb0fb88..000000000 --- a/tests/bugs/invalid-entrypoint-param.slang +++ /dev/null @@ -1,17 +0,0 @@ -//TEST:SIMPLE(filecheck=CHECK): -target spirv - -// `TT` is not valid for defining a varying entrypoint parameter, -// and we should diagnose an error. - -// CHECK: error 39028 - -struct TT -{ - Texture2D tex; -} - -[numthreads(1, 1, 1)] -void f(TT t) -{ - return; -}
\ No newline at end of file diff --git a/tests/language-feature/shader-params/entry-point-uniform-params-implicit.slang b/tests/language-feature/shader-params/entry-point-uniform-params-implicit.slang new file mode 100644 index 000000000..5f8ac7edd --- /dev/null +++ b/tests/language-feature/shader-params/entry-point-uniform-params-implicit.slang @@ -0,0 +1,42 @@ +// entry-point-uniform-params-implicit.slang + +// Test that slang can treat a compute shader parameter as `uniform` without explicit `uniform` keyword. + +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -xslang -Wno-38040 +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -xslang -Wno-38040 +//TEST:SIMPLE(filecheck=WARNING): -target spirv + +struct Data +{ + int a; + int b; +} + +int test(int val, int a, int b) +{ + return a*(val+1) + b*(val+2); +} + +[numthreads(4, 1, 1)] +[shader("compute")] +void computeMain( + +//TEST_INPUT:uniform(data=[256 1]):name=d +// WARNING: ([[# @LINE+1]]): warning 38040 + Data d, + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer + uniform RWStructuredBuffer<int> outputBuffer, + + int3 dispatchThreadID : SV_DispatchThreadID) +{ + int tid = dispatchThreadID.x; + int inVal = tid; + int outVal = test(inVal, d.a, d.b); + outputBuffer[tid] = outVal; + + // CHECK: 102 + // CHECK: 203 + // CHECK: 304 + // CHECK: 405 +} diff --git a/tests/reflection/attribute.slang.expected b/tests/reflection/attribute.slang.expected index 9b51aae32..a5bce747b 100644 --- a/tests/reflection/attribute.slang.expected +++ b/tests/reflection/attribute.slang.expected @@ -288,8 +288,7 @@ standard output = { ] } ], - "stage": "compute", - "binding": {"kind": "varyingInput", "index": 0}, + "binding": {"kind": "uniform", "offset": 0, "size": 4}, "type": { "kind": "scalar", "scalarType": "float32" @@ -304,8 +303,7 @@ standard output = { ] } ], - "stage": "compute", - "binding": {"kind": "varyingInput", "index": 1}, + "binding": {"kind": "uniform", "offset": 4, "size": 4}, "type": { "kind": "scalar", "scalarType": "float32" |
