summaryrefslogtreecommitdiffstats
path: root/source/slang/reflection.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/reflection.cpp')
-rw-r--r--source/slang/reflection.cpp55
1 files changed, 50 insertions, 5 deletions
diff --git a/source/slang/reflection.cpp b/source/slang/reflection.cpp
index 3c89194c8..aaeae5595 100644
--- a/source/slang/reflection.cpp
+++ b/source/slang/reflection.cpp
@@ -204,7 +204,7 @@ SLANG_API size_t spReflectionType_GetElementCount(SlangReflectionType* inType)
if(auto arrayType = dynamic_cast<ArrayExpressionType*>(type))
{
- return (size_t) GetIntVal(arrayType->ArrayLength);
+ return arrayType->ArrayLength ? (size_t) GetIntVal(arrayType->ArrayLength) : 0;
}
else if( auto vectorType = dynamic_cast<VectorExpressionType*>(type))
{
@@ -751,12 +751,43 @@ SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize(
auto entryPointFunc = entryPointLayout->entryPoint;
if(!entryPointFunc) return;
+ SlangUInt sizeAlongAxis[3] = { 1, 1, 1 };
+
+ // First look for the HLSL case, where we have an attribute attached to the entry point function
auto numThreadsAttribute = entryPointFunc->FindModifier<HLSLNumThreadsAttribute>();
- if(!numThreadsAttribute) return;
+ if (numThreadsAttribute)
+ {
+ sizeAlongAxis[0] = numThreadsAttribute->x;
+ sizeAlongAxis[1] = numThreadsAttribute->y;
+ sizeAlongAxis[2] = numThreadsAttribute->z;
+ }
+ else
+ {
+ // Fall back to the GLSL case, which requires a search over global-scope declarations
+ // to look for anything with the `local_size_*` qualifier
+ auto module = dynamic_cast<ProgramSyntaxNode*>(entryPointFunc->ParentDecl);
+ if (module)
+ {
+ for (auto dd : module->Members)
+ {
+ for (auto mod : dd->GetModifiersOfType<GLSLLocalSizeLayoutModifier>())
+ {
+ if (auto xMod = dynamic_cast<GLSLLocalSizeXLayoutModifier*>(mod))
+ sizeAlongAxis[0] = (SlangUInt) getIntegerLiteralValue(xMod->valToken);
+ else if (auto yMod = dynamic_cast<GLSLLocalSizeYLayoutModifier*>(mod))
+ sizeAlongAxis[1] = (SlangUInt) getIntegerLiteralValue(yMod->valToken);
+ else if (auto zMod = dynamic_cast<GLSLLocalSizeZLayoutModifier*>(mod))
+ sizeAlongAxis[2] = (SlangUInt) getIntegerLiteralValue(zMod->valToken);
+ }
+ }
+ }
+ }
+
+ //
- if(axisCount > 0) outSizeAlongAxis[0] = numThreadsAttribute->x;
- if(axisCount > 1) outSizeAlongAxis[1] = numThreadsAttribute->y;
- if(axisCount > 2) outSizeAlongAxis[2] = numThreadsAttribute->z;
+ if(axisCount > 0) outSizeAlongAxis[0] = sizeAlongAxis[0];
+ if(axisCount > 1) outSizeAlongAxis[1] = sizeAlongAxis[1];
+ if(axisCount > 2) outSizeAlongAxis[2] = sizeAlongAxis[2];
for( SlangUInt aa = 3; aa < axisCount; ++aa )
{
outSizeAlongAxis[aa] = 1;
@@ -1526,6 +1557,20 @@ static void emitReflectionEntryPointJSON(
write(writer, ",\n\"usesAnySampleRateInput\": true");
}
+ if (entryPoint->getStage() == SLANG_STAGE_COMPUTE)
+ {
+ SlangUInt threadGroupSize[3];
+ entryPoint->getComputeThreadGroupSize(3, threadGroupSize);
+
+ write(writer, ",\n\"threadGroupSize\": [");
+ for (int ii = 0; ii < 3; ++ii)
+ {
+ if (ii != 0) write(writer, ", ");
+ write(writer, threadGroupSize[ii]);
+ }
+ write(writer, "]");
+ }
+
dedent(writer);
write(writer, "\n}");
}