summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
authorEllie Hermaszewska <ellieh@nvidia.com>2025-09-04 04:05:26 +0800
committerGitHub <noreply@github.com>2025-09-03 20:05:26 +0000
commita766d27447aa0fcf69334c0467d9b1124892e180 (patch)
tree67ca5615e4a8c94d7454ee43375eeffc8c8a7d4c /source/slang/slang-check-decl.cpp
parentbf607e2f3fa183e9a2b18c7a98438a05247d6ed3 (diff)
Diagnose on structured buffers containing resources (#8222)
closes https://github.com/shader-slang/slang/issues/3313
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp105
1 files changed, 105 insertions, 0 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index e59cf6ad5..4711eaddd 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -2698,6 +2698,8 @@ static Expr* constructDefaultInitExprForType(SemanticsVisitor* visitor, VarDeclB
}
}
+void validateStructuredBufferElementType(SemanticsVisitor* visitor, VarDeclBase* varDecl);
+
void SemanticsDeclBodyVisitor::checkVarDeclCommon(VarDeclBase* varDecl)
{
DiagnoseIsAllowedInitExpr(varDecl, getSink());
@@ -2892,6 +2894,9 @@ void SemanticsDeclBodyVisitor::checkVarDeclCommon(VarDeclBase* varDecl)
}
}
}
+
+ validateStructuredBufferElementType(this, varDecl);
+
bool isGlobalOrLocalVar = !isGlobalShaderParameter(varDecl) && !as<ParamDecl>(varDecl) &&
(!parentDecl || isEffectivelyStatic(varDecl));
if (isGlobalOrLocalVar)
@@ -15018,6 +15023,106 @@ bool isOpaqueHandleType(Type* type)
return false;
}
+bool containsRecursiveTypeImpl(SemanticsVisitor* visitor, Type* type, HashSet<Decl*>& currentPath)
+{
+ // Skip modified types (const, etc.)
+ while (auto modifiedType = as<ModifiedType>(type))
+ type = modifiedType->getBase();
+
+ // Check if this is a StructuredBuffer type and look inside it
+ if (auto structuredBufferType = as<HLSLStructuredBufferTypeBase>(type))
+ {
+ return containsRecursiveTypeImpl(
+ visitor,
+ structuredBufferType->getElementType(),
+ currentPath);
+ }
+
+ // Check if this is an array type and look inside it
+ if (auto arrayType = as<ArrayExpressionType>(type))
+ {
+ return containsRecursiveTypeImpl(visitor, arrayType->getElementType(), currentPath);
+ }
+
+ if (auto declRefType = as<DeclRefType>(type))
+ {
+ auto typeDecl = declRefType->getDeclRef().getDecl();
+
+ // Check global cache first - if we've already fully analyzed this type, use that result
+ auto shared = visitor->getShared();
+ if (auto cachedResult = shared->m_typeContainsRecursionCache.tryGetValue(typeDecl))
+ {
+ return *cachedResult;
+ }
+
+ // If we're currently exploring this type, we found a cycle!
+ if (currentPath.contains(typeDecl))
+ {
+ return true;
+ }
+
+ // Add to current exploration path
+ currentPath.add(typeDecl);
+
+ bool hasRecursion = false;
+
+ // Check members if it's an aggregate type
+ if (auto aggTypeDecl = declRefType->getDeclRef().as<AggTypeDecl>())
+ {
+ for (auto member : aggTypeDecl.getDecl()->getMembersOfType<VarDeclBase>())
+ {
+ if (isEffectivelyStatic(member))
+ continue;
+
+ if (containsRecursiveTypeImpl(visitor, member->getType(), currentPath))
+ {
+ hasRecursion = true;
+ break;
+ }
+ }
+ }
+
+ // Remove from current exploration path
+ currentPath.remove(typeDecl);
+
+ // Cache the result globally
+ shared->m_typeContainsRecursionCache[typeDecl] = hasRecursion;
+
+ return hasRecursion;
+ }
+
+ return false;
+}
+
+bool containsRecursiveType(SemanticsVisitor* visitor, Type* type)
+{
+ HashSet<Decl*> currentPath;
+ return containsRecursiveTypeImpl(visitor, type, currentPath);
+}
+
+void validateStructuredBufferElementType(SemanticsVisitor* visitor, VarDeclBase* varDecl)
+{
+ auto type = unwrapArrayType(varDecl->getType());
+
+ // Check if this is a StructuredBuffer type
+ auto structuredBufferType = as<HLSLStructuredBufferTypeBase>(type);
+
+ if (!structuredBufferType)
+ return;
+
+ // Get the element type
+ auto elementType = structuredBufferType->getElementType();
+
+ // Check if the element type contains recursive references
+ if (containsRecursiveType(visitor, elementType))
+ {
+ visitor->getSink()->diagnose(
+ varDecl->loc,
+ Diagnostics::recursiveTypesFoundInStructuredBuffer,
+ elementType);
+ }
+}
+
void diagnoseMissingCapabilityProvenance(
CompilerOptionSet& optionSet,
DiagnosticSink* sink,