summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorjsmall-nvidia <jsmall@nvidia.com>2019-02-08 11:41:26 -0500
committerGitHub <noreply@github.com>2019-02-08 11:41:26 -0500
commitef08bdaf501982ae24a73200e55f99157e4b7e6a (patch)
tree2f8d5f2100d5b9c8bb5c995b6417394e13b2eba6
parentc34a5e7ed2da03c9ceaddd167e4b0421246b0c25 (diff)
Hotfix/dispatch thread id improvements (#834)
* * Make vector comparisons out correct functions on glsl * Test for vector comparisons * Typo fixes * Glsl vector comparisons use functions. * Added a coercion test. * Do checking for the SV_DispatchThreadId type to see if it appears valid. * Fix typo * Make glsl do type conversion for SV_DispatchThreadID parameter. * Fix glsl to match func-resource-param-array with changes to how SV_DispatchThreadID changes.
-rw-r--r--source/slang/check.cpp62
-rw-r--r--source/slang/diagnostic-defs.h4
-rw-r--r--source/slang/ir-glsl-legalize.cpp5
-rw-r--r--tests/bugs/vec-compare.slang2
-rw-r--r--tests/cross-compile/func-resource-param-array.slang.glsl22
5 files changed, 84 insertions, 11 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp
index cce6a545a..8dfd7e640 100644
--- a/source/slang/check.cpp
+++ b/source/slang/check.cpp
@@ -9228,6 +9228,40 @@ namespace Slang
return entryPointFuncDecl;
}
+ static bool isValidThreadDispatchIDType(Type* type)
+ {
+ // Can accept a single int/unit
+ {
+ auto basicType = as<BasicExpressionType>(type);
+ if (basicType)
+ {
+ return (basicType->baseType == BaseType::Int || basicType->baseType == BaseType::UInt);
+ }
+ }
+ // Can be an int/uint vector from size 1 to 3
+ {
+ auto vectorType = as<VectorExpressionType>(type);
+ if (!vectorType)
+ {
+ return false;
+ }
+ auto elemCount = as<ConstantIntVal>(vectorType->elementCount);
+ if (elemCount->value < 1 || elemCount->value > 3)
+ {
+ return false;
+ }
+ // Must be a basic type
+ auto basicType = as<BasicExpressionType>(vectorType->elementType);
+ if (!basicType)
+ {
+ return false;
+ }
+
+ // Must be integral
+ return (basicType->baseType == BaseType::Int || basicType->baseType == BaseType::UInt);
+ }
+ }
+
// Validate that an entry point function conforms to any additional
// constraints based on the stage (and profile?) it specifies.
void validateEntryPoint(
@@ -9303,6 +9337,34 @@ namespace Slang
attr->patchConstantFuncDecl = funcDecl;
}
}
+ else if (entryPoint->getStage() == Stage::Compute)
+ {
+ auto funcDecl = entryPoint->getFuncDecl();
+
+ auto params = funcDecl->GetParameters();
+
+ for (const auto& param : params)
+ {
+ if (auto semantic = param->FindModifier<HLSLSimpleSemantic>())
+ {
+ const auto& semanticToken = semantic->name;
+
+ String lowerName = String(semanticToken.Content).ToLower();
+
+ if (lowerName == "sv_dispatchthreadid")
+ {
+ Type* paramType = param->getType();
+
+ if (!isValidThreadDispatchIDType(paramType))
+ {
+ String typeString = paramType->ToString();
+ sink->diagnose(param->loc, Diagnostics::invalidDispatchThreadIDType, typeString);
+ return;
+ }
+ }
+ }
+ }
+ }
}
// Given an `EntryPointRequest` specified via API or command line options,
diff --git a/source/slang/diagnostic-defs.h b/source/slang/diagnostic-defs.h
index 9a8446a03..3e9d9043c 100644
--- a/source/slang/diagnostic-defs.h
+++ b/source/slang/diagnostic-defs.h
@@ -351,12 +351,14 @@ DIAGNOSTIC(38021, Error, typeArgumentDoesNotConformToInterface, "type argument `
DIAGNOSTIC(38022, Error, cannotSpecializeGlobalGenericToItself, "the global type parameter '$0' cannot be specialized to itself")
DIAGNOSTIC(38023, Error, cannotSpecializeGlobalGenericToAnotherGenericParam, "the global type parameter '$0' cannot be specialized using another global type parameter ('$1')")
+
+DIAGNOSTIC(38024, Error, invalidDispatchThreadIDType, "parameter with SV_DispatchThreadID must be either scalar or vector (1 to 3) of uint/int but is $0");
+
DIAGNOSTIC(-1, Note, noteWhenCompilingEntryPoint, "when compiling entry point '$0'")
DIAGNOSTIC(38200, Error, recursiveModuleImport, "module `$0` recursively imports itself")
DIAGNOSTIC(39999, Fatal, errorInImportedModule, "error in imported module, compilation ceased.")
-
// 39xxx - Type layout and parameter binding.
DIAGNOSTIC(39000, Error, conflictingExplicitBindingsForParameter, "conflicting explicit bindings for parameter '$0'")
diff --git a/source/slang/ir-glsl-legalize.cpp b/source/slang/ir-glsl-legalize.cpp
index 5b60314eb..57d493f1a 100644
--- a/source/slang/ir-glsl-legalize.cpp
+++ b/source/slang/ir-glsl-legalize.cpp
@@ -304,6 +304,9 @@ GLSLSystemValueInfo* getGLSLSystemValueInfo(
else if(semanticName == "sv_dispatchthreadid")
{
name = "gl_GlobalInvocationID";
+
+ auto builder = context->getBuilder();
+ requiredType = builder->getVectorType(builder->getBasicType(BaseType::UInt), builder->getIntValue(builder->getIntType(), 3));
}
else if(semanticName == "sv_domainlocation")
{
@@ -514,7 +517,7 @@ ScalarizedVal createSimpleGLSLGlobalVarying(
//
// Our IR global shader parameters are read-only, just
// like our IR function parameters, and need a wrapper
- // `Out<...>` type to represent otuputs.
+ // `Out<...>` type to represent outputs.
//
bool isOutput = kind == LayoutResourceKind::VaryingOutput;
IRType* paramType = isOutput ? builder->getOutType(type) : type;
diff --git a/tests/bugs/vec-compare.slang b/tests/bugs/vec-compare.slang
index 0eaec0191..b3075efe9 100644
--- a/tests/bugs/vec-compare.slang
+++ b/tests/bugs/vec-compare.slang
@@ -7,7 +7,7 @@
RWStructuredBuffer<int> outputBuffer;
[numthreads(4,4,1)]
-void computeMain(uint3 pixelIndex : SV_DispatchThreadID)
+void computeMain(uint2 pixelIndex : SV_DispatchThreadID)
{
// We will test floats, uints, and int vectors
diff --git a/tests/cross-compile/func-resource-param-array.slang.glsl b/tests/cross-compile/func-resource-param-array.slang.glsl
index 6224ccd1c..d7f9c17bc 100644
--- a/tests/cross-compile/func-resource-param-array.slang.glsl
+++ b/tests/cross-compile/func-resource-param-array.slang.glsl
@@ -25,11 +25,15 @@
#define g_c_t _S9
#define g_c_i _S10
#define g_c_j _S11
-#define tmp_f_a_ii _S12
-#define tmp_f_a_jj _S13
-#define tmp_f_b _S14
-#define tmp_g_b _S15
-#define tmp_g_c _S16
+
+#define tid _S12
+
+#define tmp_f_a_ii _S13
+#define tmp_f_a_jj _S14
+
+#define tmp_f_b _S15
+#define tmp_g_b _S16
+#define tmp_g_c _S17
layout(std430, binding = 0) buffer a_block {
int _data[];
@@ -67,9 +71,11 @@ layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void main()
{
- uint ii = gl_GlobalInvocationID.x;
- uint jj = gl_GlobalInvocationID.y;
- uint kk = gl_GlobalInvocationID.z;
+ uvec3 tid = uvec3(gl_GlobalInvocationID);
+
+ uint ii = tid.x;
+ uint jj = tid.y;
+ uint kk = tid.z;
int tmp_f_a_ii = f_a(ii);