summaryrefslogtreecommitdiffstats
path: root/source/slang/check.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/check.cpp')
-rw-r--r--source/slang/check.cpp62
1 files changed, 62 insertions, 0 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp
index cce6a545a..8dfd7e640 100644
--- a/source/slang/check.cpp
+++ b/source/slang/check.cpp
@@ -9228,6 +9228,40 @@ namespace Slang
return entryPointFuncDecl;
}
+ static bool isValidThreadDispatchIDType(Type* type)
+ {
+ // Can accept a single int/unit
+ {
+ auto basicType = as<BasicExpressionType>(type);
+ if (basicType)
+ {
+ return (basicType->baseType == BaseType::Int || basicType->baseType == BaseType::UInt);
+ }
+ }
+ // Can be an int/uint vector from size 1 to 3
+ {
+ auto vectorType = as<VectorExpressionType>(type);
+ if (!vectorType)
+ {
+ return false;
+ }
+ auto elemCount = as<ConstantIntVal>(vectorType->elementCount);
+ if (elemCount->value < 1 || elemCount->value > 3)
+ {
+ return false;
+ }
+ // Must be a basic type
+ auto basicType = as<BasicExpressionType>(vectorType->elementType);
+ if (!basicType)
+ {
+ return false;
+ }
+
+ // Must be integral
+ return (basicType->baseType == BaseType::Int || basicType->baseType == BaseType::UInt);
+ }
+ }
+
// Validate that an entry point function conforms to any additional
// constraints based on the stage (and profile?) it specifies.
void validateEntryPoint(
@@ -9303,6 +9337,34 @@ namespace Slang
attr->patchConstantFuncDecl = funcDecl;
}
}
+ else if (entryPoint->getStage() == Stage::Compute)
+ {
+ auto funcDecl = entryPoint->getFuncDecl();
+
+ auto params = funcDecl->GetParameters();
+
+ for (const auto& param : params)
+ {
+ if (auto semantic = param->FindModifier<HLSLSimpleSemantic>())
+ {
+ const auto& semanticToken = semantic->name;
+
+ String lowerName = String(semanticToken.Content).ToLower();
+
+ if (lowerName == "sv_dispatchthreadid")
+ {
+ Type* paramType = param->getType();
+
+ if (!isValidThreadDispatchIDType(paramType))
+ {
+ String typeString = paramType->ToString();
+ sink->diagnose(param->loc, Diagnostics::invalidDispatchThreadIDType, typeString);
+ return;
+ }
+ }
+ }
+ }
+ }
}
// Given an `EntryPointRequest` specified via API or command line options,