diff options
| author | Ellie Hermaszewska <ellieh@nvidia.com> | 2025-09-04 04:05:26 +0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-09-03 20:05:26 +0000 |
| commit | a766d27447aa0fcf69334c0467d9b1124892e180 (patch) | |
| tree | 67ca5615e4a8c94d7454ee43375eeffc8c8a7d4c /source/slang/slang-check-decl.cpp | |
| parent | bf607e2f3fa183e9a2b18c7a98438a05247d6ed3 (diff) | |
Diagnose on structured buffers containing resources (#8222)
closes https://github.com/shader-slang/slang/issues/3313
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 105 |
1 files changed, 105 insertions, 0 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index e59cf6ad5..4711eaddd 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -2698,6 +2698,8 @@ static Expr* constructDefaultInitExprForType(SemanticsVisitor* visitor, VarDeclB } } +void validateStructuredBufferElementType(SemanticsVisitor* visitor, VarDeclBase* varDecl); + void SemanticsDeclBodyVisitor::checkVarDeclCommon(VarDeclBase* varDecl) { DiagnoseIsAllowedInitExpr(varDecl, getSink()); @@ -2892,6 +2894,9 @@ void SemanticsDeclBodyVisitor::checkVarDeclCommon(VarDeclBase* varDecl) } } } + + validateStructuredBufferElementType(this, varDecl); + bool isGlobalOrLocalVar = !isGlobalShaderParameter(varDecl) && !as<ParamDecl>(varDecl) && (!parentDecl || isEffectivelyStatic(varDecl)); if (isGlobalOrLocalVar) @@ -15018,6 +15023,106 @@ bool isOpaqueHandleType(Type* type) return false; } +bool containsRecursiveTypeImpl(SemanticsVisitor* visitor, Type* type, HashSet<Decl*>& currentPath) +{ + // Skip modified types (const, etc.) + while (auto modifiedType = as<ModifiedType>(type)) + type = modifiedType->getBase(); + + // Check if this is a StructuredBuffer type and look inside it + if (auto structuredBufferType = as<HLSLStructuredBufferTypeBase>(type)) + { + return containsRecursiveTypeImpl( + visitor, + structuredBufferType->getElementType(), + currentPath); + } + + // Check if this is an array type and look inside it + if (auto arrayType = as<ArrayExpressionType>(type)) + { + return containsRecursiveTypeImpl(visitor, arrayType->getElementType(), currentPath); + } + + if (auto declRefType = as<DeclRefType>(type)) + { + auto typeDecl = declRefType->getDeclRef().getDecl(); + + // Check global cache first - if we've already fully analyzed this type, use that result + auto shared = visitor->getShared(); + if (auto cachedResult = shared->m_typeContainsRecursionCache.tryGetValue(typeDecl)) + { + return *cachedResult; + } + + // If we're currently exploring this type, we found a cycle! + if (currentPath.contains(typeDecl)) + { + return true; + } + + // Add to current exploration path + currentPath.add(typeDecl); + + bool hasRecursion = false; + + // Check members if it's an aggregate type + if (auto aggTypeDecl = declRefType->getDeclRef().as<AggTypeDecl>()) + { + for (auto member : aggTypeDecl.getDecl()->getMembersOfType<VarDeclBase>()) + { + if (isEffectivelyStatic(member)) + continue; + + if (containsRecursiveTypeImpl(visitor, member->getType(), currentPath)) + { + hasRecursion = true; + break; + } + } + } + + // Remove from current exploration path + currentPath.remove(typeDecl); + + // Cache the result globally + shared->m_typeContainsRecursionCache[typeDecl] = hasRecursion; + + return hasRecursion; + } + + return false; +} + +bool containsRecursiveType(SemanticsVisitor* visitor, Type* type) +{ + HashSet<Decl*> currentPath; + return containsRecursiveTypeImpl(visitor, type, currentPath); +} + +void validateStructuredBufferElementType(SemanticsVisitor* visitor, VarDeclBase* varDecl) +{ + auto type = unwrapArrayType(varDecl->getType()); + + // Check if this is a StructuredBuffer type + auto structuredBufferType = as<HLSLStructuredBufferTypeBase>(type); + + if (!structuredBufferType) + return; + + // Get the element type + auto elementType = structuredBufferType->getElementType(); + + // Check if the element type contains recursive references + if (containsRecursiveType(visitor, elementType)) + { + visitor->getSink()->diagnose( + varDecl->loc, + Diagnostics::recursiveTypesFoundInStructuredBuffer, + elementType); + } +} + void diagnoseMissingCapabilityProvenance( CompilerOptionSet& optionSet, DiagnosticSink* sink, |
