summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-check-shader.cpp76
-rw-r--r--source/slang/slang-diagnostic-defs.h3
-rw-r--r--tests/bugs/invalid-entrypoint-param.slang17
-rw-r--r--tests/language-feature/shader-params/entry-point-uniform-params-implicit.slang42
-rw-r--r--tests/reflection/attribute.slang.expected6
5 files changed, 104 insertions, 40 deletions
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp
index a802906a7..99205e522 100644
--- a/source/slang/slang-check-shader.cpp
+++ b/source/slang/slang-check-shader.cpp
@@ -333,7 +333,17 @@ namespace Slang
bool isBuiltinParameterType(Type* type)
{
- return as<BuiltinType>(type) != nullptr;
+ if (!as<BuiltinType>(type))
+ return false;
+ if (as<BasicExpressionType>(type))
+ return false;
+ if (as<VectorExpressionType>(type))
+ return false;
+ if (as<MatrixExpressionType>(type))
+ return false;
+ if (auto arrayType = as<ArrayExpressionType>(type))
+ return isBuiltinParameterType(arrayType->getElementType());
+ return true;
}
bool doStructFieldsHaveSemanticImpl(Type* type, HashSet<Type*>& seenTypes)
@@ -345,18 +355,20 @@ namespace Slang
if (!structDecl)
return false;
seenTypes.add(type);
+ bool hasFields = false;
for (auto field : structDecl->getFields())
{
+ hasFields = true;
if (!field->findModifier<HLSLSemantic>())
{
- if (!seenTypes.contains(type))
+ if (!seenTypes.contains(field->getType()))
{
if (!doStructFieldsHaveSemanticImpl(field->getType(), seenTypes))
return false;
}
}
}
- return true;
+ return hasFields;
}
bool doStructFieldsHaveSemantic(Type* type)
@@ -488,30 +500,58 @@ namespace Slang
}
}
+ bool canHaveVaryingInput = false;
+ switch (stage)
+ {
+ case Stage::Vertex:
+ case Stage::Fragment:
+ case Stage::Miss:
+ case Stage::AnyHit:
+ case Stage::ClosestHit:
+ case Stage::Callable:
+ case Stage::Geometry:
+ case Stage::Mesh:
+ case Stage::Hull:
+ case Stage::Domain:
+ canHaveVaryingInput = true;
+ break;
+ default:
+ break;
+ }
+
for (const auto& param : entryPointFuncDecl->getParameters())
{
if (isUniformParameterType(param->getType()))
{
// Automatically add `uniform` modifier to entry point parameters.
if (!param->hasModifier<HLSLUniformModifier>())
- addModifier(param, getCurrentASTBuilder()->create<HLSLUniformModifier>());
- }
- else if (isBuiltinParameterType(param->getType()))
- {
- }
- else
- {
- // For all non-uniform parameters of a general type, we require the parameter be associated with
- // a system value semantic.
- if (!param->hasModifier<HLSLUniformModifier>())
{
- if (!param->findModifier<HLSLSemantic>())
- {
- if (!doStructFieldsHaveSemantic(param->getType()))
- sink->diagnose(param, Diagnostics::nonUniformEntryPointParameterMustHaveSemantic, param->getName());
- }
+ addModifier(param, getCurrentASTBuilder()->create<HLSLUniformModifier>());
+ continue;
}
}
+
+ if (canHaveVaryingInput)
+ continue;
+
+ // If the stage doesn't allow varying input/output,
+ // we require the parameter to be associated with a system value semantic.
+ if (param->hasModifier<HLSLUniformModifier>())
+ continue;
+ if (param->findModifier<HLSLSemantic>())
+ continue;
+
+ bool isBuiltinType = isBuiltinParameterType(param->getType());
+ if (isBuiltinType)
+ continue;
+
+ if (doStructFieldsHaveSemantic(param->getType()))
+ continue;
+
+ // The user is defining a parameter with no 'uniform' modifier for a stage that doesn't support
+ // varying input/output. We will automatically convert it to a 'uniform' parameter, and diagnose a warning.
+ addModifier(param, getCurrentASTBuilder()->create<HLSLUniformModifier>());
+ sink->diagnose(param, Diagnostics::nonUniformEntryPointParameterTreatedAsUniform, param->getName());
}
for (auto target : linkage->targets)
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index cb3e39dd4..487f264d5 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -649,7 +649,8 @@ DIAGNOSTIC(38033, Error, cannotUseNoDiffInNonDifferentiableFunc, "cannot use 'no
DIAGNOSTIC(38034, Error, cannotUseConstRefOnDifferentiableParameter, "cannot use '__constref' on a differentiable parameter.")
DIAGNOSTIC(38034, Error, cannotUseConstRefOnDifferentiableMemberMethod, "cannot use '[constref]' on a differentiable member method of a differentiable type.")
-DIAGNOSTIC(38040, Error, nonUniformEntryPointParameterMustHaveSemantic, "non-uniform parameter '$0' must have a system-value semantic.")
+DIAGNOSTIC(38040, Warning, nonUniformEntryPointParameterTreatedAsUniform, "parameter '$0' is treated as 'uniform' because it does not have a system-value semantic.")
+
DIAGNOSTIC(38200, Error, recursiveModuleImport, "module `$0` recursively imports itself")
DIAGNOSTIC(39999, Error, errorInImportedModule, "import of module '$0' failed because of a compilation error")
diff --git a/tests/bugs/invalid-entrypoint-param.slang b/tests/bugs/invalid-entrypoint-param.slang
deleted file mode 100644
index bafb0fb88..000000000
--- a/tests/bugs/invalid-entrypoint-param.slang
+++ /dev/null
@@ -1,17 +0,0 @@
-//TEST:SIMPLE(filecheck=CHECK): -target spirv
-
-// `TT` is not valid for defining a varying entrypoint parameter,
-// and we should diagnose an error.
-
-// CHECK: error 39028
-
-struct TT
-{
- Texture2D tex;
-}
-
-[numthreads(1, 1, 1)]
-void f(TT t)
-{
- return;
-} \ No newline at end of file
diff --git a/tests/language-feature/shader-params/entry-point-uniform-params-implicit.slang b/tests/language-feature/shader-params/entry-point-uniform-params-implicit.slang
new file mode 100644
index 000000000..5f8ac7edd
--- /dev/null
+++ b/tests/language-feature/shader-params/entry-point-uniform-params-implicit.slang
@@ -0,0 +1,42 @@
+// entry-point-uniform-params-implicit.slang
+
+// Test that slang can treat a compute shader parameter as `uniform` without explicit `uniform` keyword.
+
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -xslang -Wno-38040
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -xslang -Wno-38040
+//TEST:SIMPLE(filecheck=WARNING): -target spirv
+
+struct Data
+{
+ int a;
+ int b;
+}
+
+int test(int val, int a, int b)
+{
+ return a*(val+1) + b*(val+2);
+}
+
+[numthreads(4, 1, 1)]
+[shader("compute")]
+void computeMain(
+
+//TEST_INPUT:uniform(data=[256 1]):name=d
+// WARNING: ([[# @LINE+1]]): warning 38040
+ Data d,
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+ uniform RWStructuredBuffer<int> outputBuffer,
+
+ int3 dispatchThreadID : SV_DispatchThreadID)
+{
+ int tid = dispatchThreadID.x;
+ int inVal = tid;
+ int outVal = test(inVal, d.a, d.b);
+ outputBuffer[tid] = outVal;
+
+ // CHECK: 102
+ // CHECK: 203
+ // CHECK: 304
+ // CHECK: 405
+}
diff --git a/tests/reflection/attribute.slang.expected b/tests/reflection/attribute.slang.expected
index 9b51aae32..a5bce747b 100644
--- a/tests/reflection/attribute.slang.expected
+++ b/tests/reflection/attribute.slang.expected
@@ -288,8 +288,7 @@ standard output = {
]
}
],
- "stage": "compute",
- "binding": {"kind": "varyingInput", "index": 0},
+ "binding": {"kind": "uniform", "offset": 0, "size": 4},
"type": {
"kind": "scalar",
"scalarType": "float32"
@@ -304,8 +303,7 @@ standard output = {
]
}
],
- "stage": "compute",
- "binding": {"kind": "varyingInput", "index": 1},
+ "binding": {"kind": "uniform", "offset": 4, "size": 4},
"type": {
"kind": "scalar",
"scalarType": "float32"