diff options
Diffstat (limited to 'source/slang/reflection.cpp')
| -rw-r--r-- | source/slang/reflection.cpp | 55 |
1 files changed, 50 insertions, 5 deletions
diff --git a/source/slang/reflection.cpp b/source/slang/reflection.cpp index 3c89194c8..aaeae5595 100644 --- a/source/slang/reflection.cpp +++ b/source/slang/reflection.cpp @@ -204,7 +204,7 @@ SLANG_API size_t spReflectionType_GetElementCount(SlangReflectionType* inType) if(auto arrayType = dynamic_cast<ArrayExpressionType*>(type)) { - return (size_t) GetIntVal(arrayType->ArrayLength); + return arrayType->ArrayLength ? (size_t) GetIntVal(arrayType->ArrayLength) : 0; } else if( auto vectorType = dynamic_cast<VectorExpressionType*>(type)) { @@ -751,12 +751,43 @@ SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize( auto entryPointFunc = entryPointLayout->entryPoint; if(!entryPointFunc) return; + SlangUInt sizeAlongAxis[3] = { 1, 1, 1 }; + + // First look for the HLSL case, where we have an attribute attached to the entry point function auto numThreadsAttribute = entryPointFunc->FindModifier<HLSLNumThreadsAttribute>(); - if(!numThreadsAttribute) return; + if (numThreadsAttribute) + { + sizeAlongAxis[0] = numThreadsAttribute->x; + sizeAlongAxis[1] = numThreadsAttribute->y; + sizeAlongAxis[2] = numThreadsAttribute->z; + } + else + { + // Fall back to the GLSL case, which requires a search over global-scope declarations + // to look for anything with the `local_size_*` qualifier + auto module = dynamic_cast<ProgramSyntaxNode*>(entryPointFunc->ParentDecl); + if (module) + { + for (auto dd : module->Members) + { + for (auto mod : dd->GetModifiersOfType<GLSLLocalSizeLayoutModifier>()) + { + if (auto xMod = dynamic_cast<GLSLLocalSizeXLayoutModifier*>(mod)) + sizeAlongAxis[0] = (SlangUInt) getIntegerLiteralValue(xMod->valToken); + else if (auto yMod = dynamic_cast<GLSLLocalSizeYLayoutModifier*>(mod)) + sizeAlongAxis[1] = (SlangUInt) getIntegerLiteralValue(yMod->valToken); + else if (auto zMod = dynamic_cast<GLSLLocalSizeZLayoutModifier*>(mod)) + sizeAlongAxis[2] = (SlangUInt) getIntegerLiteralValue(zMod->valToken); + } + } + } + } + + // - if(axisCount > 0) outSizeAlongAxis[0] = numThreadsAttribute->x; - if(axisCount > 1) outSizeAlongAxis[1] = numThreadsAttribute->y; - if(axisCount > 2) outSizeAlongAxis[2] = numThreadsAttribute->z; + if(axisCount > 0) outSizeAlongAxis[0] = sizeAlongAxis[0]; + if(axisCount > 1) outSizeAlongAxis[1] = sizeAlongAxis[1]; + if(axisCount > 2) outSizeAlongAxis[2] = sizeAlongAxis[2]; for( SlangUInt aa = 3; aa < axisCount; ++aa ) { outSizeAlongAxis[aa] = 1; @@ -1526,6 +1557,20 @@ static void emitReflectionEntryPointJSON( write(writer, ",\n\"usesAnySampleRateInput\": true"); } + if (entryPoint->getStage() == SLANG_STAGE_COMPUTE) + { + SlangUInt threadGroupSize[3]; + entryPoint->getComputeThreadGroupSize(3, threadGroupSize); + + write(writer, ",\n\"threadGroupSize\": ["); + for (int ii = 0; ii < 3; ++ii) + { + if (ii != 0) write(writer, ", "); + write(writer, threadGroupSize[ii]); + } + write(writer, "]"); + } + dedent(writer); write(writer, "\n}"); } |
