diff options
| author | jsmall-nvidia <jsmall@nvidia.com> | 2019-02-08 11:41:26 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2019-02-08 11:41:26 -0500 |
| commit | ef08bdaf501982ae24a73200e55f99157e4b7e6a (patch) | |
| tree | 2f8d5f2100d5b9c8bb5c995b6417394e13b2eba6 | |
| parent | c34a5e7ed2da03c9ceaddd167e4b0421246b0c25 (diff) | |
Hotfix/dispatch thread id improvements (#834)
* * Make vector comparisons out correct functions on glsl
* Test for vector comparisons
* Typo fixes
* Glsl vector comparisons use functions.
* Added a coercion test.
* Do checking for the SV_DispatchThreadId type to see if it appears valid.
* Fix typo
* Make glsl do type conversion for SV_DispatchThreadID parameter.
* Fix glsl to match func-resource-param-array with changes to how SV_DispatchThreadID changes.
| -rw-r--r-- | source/slang/check.cpp | 62 | ||||
| -rw-r--r-- | source/slang/diagnostic-defs.h | 4 | ||||
| -rw-r--r-- | source/slang/ir-glsl-legalize.cpp | 5 | ||||
| -rw-r--r-- | tests/bugs/vec-compare.slang | 2 | ||||
| -rw-r--r-- | tests/cross-compile/func-resource-param-array.slang.glsl | 22 |
5 files changed, 84 insertions, 11 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp index cce6a545a..8dfd7e640 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -9228,6 +9228,40 @@ namespace Slang return entryPointFuncDecl; } + static bool isValidThreadDispatchIDType(Type* type) + { + // Can accept a single int/unit + { + auto basicType = as<BasicExpressionType>(type); + if (basicType) + { + return (basicType->baseType == BaseType::Int || basicType->baseType == BaseType::UInt); + } + } + // Can be an int/uint vector from size 1 to 3 + { + auto vectorType = as<VectorExpressionType>(type); + if (!vectorType) + { + return false; + } + auto elemCount = as<ConstantIntVal>(vectorType->elementCount); + if (elemCount->value < 1 || elemCount->value > 3) + { + return false; + } + // Must be a basic type + auto basicType = as<BasicExpressionType>(vectorType->elementType); + if (!basicType) + { + return false; + } + + // Must be integral + return (basicType->baseType == BaseType::Int || basicType->baseType == BaseType::UInt); + } + } + // Validate that an entry point function conforms to any additional // constraints based on the stage (and profile?) it specifies. void validateEntryPoint( @@ -9303,6 +9337,34 @@ namespace Slang attr->patchConstantFuncDecl = funcDecl; } } + else if (entryPoint->getStage() == Stage::Compute) + { + auto funcDecl = entryPoint->getFuncDecl(); + + auto params = funcDecl->GetParameters(); + + for (const auto& param : params) + { + if (auto semantic = param->FindModifier<HLSLSimpleSemantic>()) + { + const auto& semanticToken = semantic->name; + + String lowerName = String(semanticToken.Content).ToLower(); + + if (lowerName == "sv_dispatchthreadid") + { + Type* paramType = param->getType(); + + if (!isValidThreadDispatchIDType(paramType)) + { + String typeString = paramType->ToString(); + sink->diagnose(param->loc, Diagnostics::invalidDispatchThreadIDType, typeString); + return; + } + } + } + } + } } // Given an `EntryPointRequest` specified via API or command line options, diff --git a/source/slang/diagnostic-defs.h b/source/slang/diagnostic-defs.h index 9a8446a03..3e9d9043c 100644 --- a/source/slang/diagnostic-defs.h +++ b/source/slang/diagnostic-defs.h @@ -351,12 +351,14 @@ DIAGNOSTIC(38021, Error, typeArgumentDoesNotConformToInterface, "type argument ` DIAGNOSTIC(38022, Error, cannotSpecializeGlobalGenericToItself, "the global type parameter '$0' cannot be specialized to itself") DIAGNOSTIC(38023, Error, cannotSpecializeGlobalGenericToAnotherGenericParam, "the global type parameter '$0' cannot be specialized using another global type parameter ('$1')") + +DIAGNOSTIC(38024, Error, invalidDispatchThreadIDType, "parameter with SV_DispatchThreadID must be either scalar or vector (1 to 3) of uint/int but is $0"); + DIAGNOSTIC(-1, Note, noteWhenCompilingEntryPoint, "when compiling entry point '$0'") DIAGNOSTIC(38200, Error, recursiveModuleImport, "module `$0` recursively imports itself") DIAGNOSTIC(39999, Fatal, errorInImportedModule, "error in imported module, compilation ceased.") - // 39xxx - Type layout and parameter binding. DIAGNOSTIC(39000, Error, conflictingExplicitBindingsForParameter, "conflicting explicit bindings for parameter '$0'") diff --git a/source/slang/ir-glsl-legalize.cpp b/source/slang/ir-glsl-legalize.cpp index 5b60314eb..57d493f1a 100644 --- a/source/slang/ir-glsl-legalize.cpp +++ b/source/slang/ir-glsl-legalize.cpp @@ -304,6 +304,9 @@ GLSLSystemValueInfo* getGLSLSystemValueInfo( else if(semanticName == "sv_dispatchthreadid") { name = "gl_GlobalInvocationID"; + + auto builder = context->getBuilder(); + requiredType = builder->getVectorType(builder->getBasicType(BaseType::UInt), builder->getIntValue(builder->getIntType(), 3)); } else if(semanticName == "sv_domainlocation") { @@ -514,7 +517,7 @@ ScalarizedVal createSimpleGLSLGlobalVarying( // // Our IR global shader parameters are read-only, just // like our IR function parameters, and need a wrapper - // `Out<...>` type to represent otuputs. + // `Out<...>` type to represent outputs. // bool isOutput = kind == LayoutResourceKind::VaryingOutput; IRType* paramType = isOutput ? builder->getOutType(type) : type; diff --git a/tests/bugs/vec-compare.slang b/tests/bugs/vec-compare.slang index 0eaec0191..b3075efe9 100644 --- a/tests/bugs/vec-compare.slang +++ b/tests/bugs/vec-compare.slang @@ -7,7 +7,7 @@ RWStructuredBuffer<int> outputBuffer; [numthreads(4,4,1)] -void computeMain(uint3 pixelIndex : SV_DispatchThreadID) +void computeMain(uint2 pixelIndex : SV_DispatchThreadID) { // We will test floats, uints, and int vectors diff --git a/tests/cross-compile/func-resource-param-array.slang.glsl b/tests/cross-compile/func-resource-param-array.slang.glsl index 6224ccd1c..d7f9c17bc 100644 --- a/tests/cross-compile/func-resource-param-array.slang.glsl +++ b/tests/cross-compile/func-resource-param-array.slang.glsl @@ -25,11 +25,15 @@ #define g_c_t _S9 #define g_c_i _S10 #define g_c_j _S11 -#define tmp_f_a_ii _S12 -#define tmp_f_a_jj _S13 -#define tmp_f_b _S14 -#define tmp_g_b _S15 -#define tmp_g_c _S16 + +#define tid _S12 + +#define tmp_f_a_ii _S13 +#define tmp_f_a_jj _S14 + +#define tmp_f_b _S15 +#define tmp_g_b _S16 +#define tmp_g_c _S17 layout(std430, binding = 0) buffer a_block { int _data[]; @@ -67,9 +71,11 @@ layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; void main() { - uint ii = gl_GlobalInvocationID.x; - uint jj = gl_GlobalInvocationID.y; - uint kk = gl_GlobalInvocationID.z; + uvec3 tid = uvec3(gl_GlobalInvocationID); + + uint ii = tid.x; + uint jj = tid.y; + uint kk = tid.z; int tmp_f_a_ii = f_a(ii); |
