diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/modifier-defs.h | 5 | ||||
| -rw-r--r-- | source/slang/parser.cpp | 3 | ||||
| -rw-r--r-- | source/slang/reflection.cpp | 55 |
3 files changed, 58 insertions, 5 deletions
diff --git a/source/slang/modifier-defs.h b/source/slang/modifier-defs.h index 665551b35..fadc68695 100644 --- a/source/slang/modifier-defs.h +++ b/source/slang/modifier-defs.h @@ -111,6 +111,11 @@ SIMPLE_SYNTAX_CLASS(GLSLSetLayoutModifier , GLSLParsedLayoutModifier) SIMPLE_SYNTAX_CLASS(GLSLLocationLayoutModifier , GLSLParsedLayoutModifier) SIMPLE_SYNTAX_CLASS(GLSLPushConstantLayoutModifier, GLSLParsedLayoutModifier) +SIMPLE_SYNTAX_CLASS(GLSLLocalSizeLayoutModifier, GLSLUnparsedLayoutModifier) +SIMPLE_SYNTAX_CLASS(GLSLLocalSizeXLayoutModifier, GLSLLocalSizeLayoutModifier) +SIMPLE_SYNTAX_CLASS(GLSLLocalSizeYLayoutModifier, GLSLLocalSizeLayoutModifier) +SIMPLE_SYNTAX_CLASS(GLSLLocalSizeZLayoutModifier, GLSLLocalSizeLayoutModifier) + // A catch-all for single-keyword modifiers SIMPLE_SYNTAX_CLASS(SimpleModifier, Modifier) diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp index e0ff85164..1ad0c6b94 100644 --- a/source/slang/parser.cpp +++ b/source/slang/parser.cpp @@ -748,6 +748,9 @@ namespace Slang CASE(set, GLSLSetLayoutModifier); CASE(location, GLSLLocationLayoutModifier); CASE(push_constant, GLSLPushConstantLayoutModifier); + CASE(local_size_x, GLSLLocalSizeXLayoutModifier); + CASE(local_size_y, GLSLLocalSizeYLayoutModifier); + CASE(local_size_z, GLSLLocalSizeZLayoutModifier); #undef CASE else diff --git a/source/slang/reflection.cpp b/source/slang/reflection.cpp index 3c89194c8..aaeae5595 100644 --- a/source/slang/reflection.cpp +++ b/source/slang/reflection.cpp @@ -204,7 +204,7 @@ SLANG_API size_t spReflectionType_GetElementCount(SlangReflectionType* inType) if(auto arrayType = dynamic_cast<ArrayExpressionType*>(type)) { - return (size_t) GetIntVal(arrayType->ArrayLength); + return arrayType->ArrayLength ? (size_t) GetIntVal(arrayType->ArrayLength) : 0; } else if( auto vectorType = dynamic_cast<VectorExpressionType*>(type)) { @@ -751,12 +751,43 @@ SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize( auto entryPointFunc = entryPointLayout->entryPoint; if(!entryPointFunc) return; + SlangUInt sizeAlongAxis[3] = { 1, 1, 1 }; + + // First look for the HLSL case, where we have an attribute attached to the entry point function auto numThreadsAttribute = entryPointFunc->FindModifier<HLSLNumThreadsAttribute>(); - if(!numThreadsAttribute) return; + if (numThreadsAttribute) + { + sizeAlongAxis[0] = numThreadsAttribute->x; + sizeAlongAxis[1] = numThreadsAttribute->y; + sizeAlongAxis[2] = numThreadsAttribute->z; + } + else + { + // Fall back to the GLSL case, which requires a search over global-scope declarations + // to look for anything with the `local_size_*` qualifier + auto module = dynamic_cast<ProgramSyntaxNode*>(entryPointFunc->ParentDecl); + if (module) + { + for (auto dd : module->Members) + { + for (auto mod : dd->GetModifiersOfType<GLSLLocalSizeLayoutModifier>()) + { + if (auto xMod = dynamic_cast<GLSLLocalSizeXLayoutModifier*>(mod)) + sizeAlongAxis[0] = (SlangUInt) getIntegerLiteralValue(xMod->valToken); + else if (auto yMod = dynamic_cast<GLSLLocalSizeYLayoutModifier*>(mod)) + sizeAlongAxis[1] = (SlangUInt) getIntegerLiteralValue(yMod->valToken); + else if (auto zMod = dynamic_cast<GLSLLocalSizeZLayoutModifier*>(mod)) + sizeAlongAxis[2] = (SlangUInt) getIntegerLiteralValue(zMod->valToken); + } + } + } + } + + // - if(axisCount > 0) outSizeAlongAxis[0] = numThreadsAttribute->x; - if(axisCount > 1) outSizeAlongAxis[1] = numThreadsAttribute->y; - if(axisCount > 2) outSizeAlongAxis[2] = numThreadsAttribute->z; + if(axisCount > 0) outSizeAlongAxis[0] = sizeAlongAxis[0]; + if(axisCount > 1) outSizeAlongAxis[1] = sizeAlongAxis[1]; + if(axisCount > 2) outSizeAlongAxis[2] = sizeAlongAxis[2]; for( SlangUInt aa = 3; aa < axisCount; ++aa ) { outSizeAlongAxis[aa] = 1; @@ -1526,6 +1557,20 @@ static void emitReflectionEntryPointJSON( write(writer, ",\n\"usesAnySampleRateInput\": true"); } + if (entryPoint->getStage() == SLANG_STAGE_COMPUTE) + { + SlangUInt threadGroupSize[3]; + entryPoint->getComputeThreadGroupSize(3, threadGroupSize); + + write(writer, ",\n\"threadGroupSize\": ["); + for (int ii = 0; ii < 3; ++ii) + { + if (ii != 0) write(writer, ", "); + write(writer, threadGroupSize[ii]); + } + write(writer, "]"); + } + dedent(writer); write(writer, "\n}"); } |
