summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorEllie Hermaszewska <ellieh@nvidia.com>2025-09-04 04:05:26 +0800
committerGitHub <noreply@github.com>2025-09-03 20:05:26 +0000
commita766d27447aa0fcf69334c0467d9b1124892e180 (patch)
tree67ca5615e4a8c94d7454ee43375eeffc8c8a7d4c /source/slang
parentbf607e2f3fa183e9a2b18c7a98438a05247d6ed3 (diff)
Diagnose on structured buffers containing resources (#8222)
closes https://github.com/shader-slang/slang/issues/3313
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-ast-base.h14
-rw-r--r--source/slang/slang-check-decl.cpp105
-rw-r--r--source/slang/slang-check-impl.h15
-rw-r--r--source/slang/slang-diagnostic-defs.h12
-rw-r--r--source/slang/slang-emit.cpp3
-rw-r--r--source/slang/slang-ir-validate.cpp145
-rw-r--r--source/slang/slang-ir-validate.h5
-rw-r--r--source/slang/slang-ir.cpp9
-rw-r--r--source/slang/slang-ir.h1
9 files changed, 296 insertions, 13 deletions
diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h
index dbbfe6b92..828bbdb5c 100644
--- a/source/slang/slang-ast-base.h
+++ b/source/slang/slang-ast-base.h
@@ -598,6 +598,20 @@ protected:
ASTBuilder* m_astBuilderForReflection;
};
+struct TypePair
+{
+ Type* type0;
+ Type* type1;
+ HashCode getHashCode() const
+ {
+ return combineHash(Slang::getHashCode(type0), Slang::getHashCode(type1));
+ }
+ bool operator==(const TypePair& other) const
+ {
+ return type0 == other.type0 && type1 == other.type1;
+ }
+};
+
template<typename T>
SLANG_FORCE_INLINE T* as(Type* obj)
{
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index e59cf6ad5..4711eaddd 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -2698,6 +2698,8 @@ static Expr* constructDefaultInitExprForType(SemanticsVisitor* visitor, VarDeclB
}
}
+void validateStructuredBufferElementType(SemanticsVisitor* visitor, VarDeclBase* varDecl);
+
void SemanticsDeclBodyVisitor::checkVarDeclCommon(VarDeclBase* varDecl)
{
DiagnoseIsAllowedInitExpr(varDecl, getSink());
@@ -2892,6 +2894,9 @@ void SemanticsDeclBodyVisitor::checkVarDeclCommon(VarDeclBase* varDecl)
}
}
}
+
+ validateStructuredBufferElementType(this, varDecl);
+
bool isGlobalOrLocalVar = !isGlobalShaderParameter(varDecl) && !as<ParamDecl>(varDecl) &&
(!parentDecl || isEffectivelyStatic(varDecl));
if (isGlobalOrLocalVar)
@@ -15018,6 +15023,106 @@ bool isOpaqueHandleType(Type* type)
return false;
}
+bool containsRecursiveTypeImpl(SemanticsVisitor* visitor, Type* type, HashSet<Decl*>& currentPath)
+{
+ // Skip modified types (const, etc.)
+ while (auto modifiedType = as<ModifiedType>(type))
+ type = modifiedType->getBase();
+
+ // Check if this is a StructuredBuffer type and look inside it
+ if (auto structuredBufferType = as<HLSLStructuredBufferTypeBase>(type))
+ {
+ return containsRecursiveTypeImpl(
+ visitor,
+ structuredBufferType->getElementType(),
+ currentPath);
+ }
+
+ // Check if this is an array type and look inside it
+ if (auto arrayType = as<ArrayExpressionType>(type))
+ {
+ return containsRecursiveTypeImpl(visitor, arrayType->getElementType(), currentPath);
+ }
+
+ if (auto declRefType = as<DeclRefType>(type))
+ {
+ auto typeDecl = declRefType->getDeclRef().getDecl();
+
+ // Check global cache first - if we've already fully analyzed this type, use that result
+ auto shared = visitor->getShared();
+ if (auto cachedResult = shared->m_typeContainsRecursionCache.tryGetValue(typeDecl))
+ {
+ return *cachedResult;
+ }
+
+ // If we're currently exploring this type, we found a cycle!
+ if (currentPath.contains(typeDecl))
+ {
+ return true;
+ }
+
+ // Add to current exploration path
+ currentPath.add(typeDecl);
+
+ bool hasRecursion = false;
+
+ // Check members if it's an aggregate type
+ if (auto aggTypeDecl = declRefType->getDeclRef().as<AggTypeDecl>())
+ {
+ for (auto member : aggTypeDecl.getDecl()->getMembersOfType<VarDeclBase>())
+ {
+ if (isEffectivelyStatic(member))
+ continue;
+
+ if (containsRecursiveTypeImpl(visitor, member->getType(), currentPath))
+ {
+ hasRecursion = true;
+ break;
+ }
+ }
+ }
+
+ // Remove from current exploration path
+ currentPath.remove(typeDecl);
+
+ // Cache the result globally
+ shared->m_typeContainsRecursionCache[typeDecl] = hasRecursion;
+
+ return hasRecursion;
+ }
+
+ return false;
+}
+
+bool containsRecursiveType(SemanticsVisitor* visitor, Type* type)
+{
+ HashSet<Decl*> currentPath;
+ return containsRecursiveTypeImpl(visitor, type, currentPath);
+}
+
+void validateStructuredBufferElementType(SemanticsVisitor* visitor, VarDeclBase* varDecl)
+{
+ auto type = unwrapArrayType(varDecl->getType());
+
+ // Check if this is a StructuredBuffer type
+ auto structuredBufferType = as<HLSLStructuredBufferTypeBase>(type);
+
+ if (!structuredBufferType)
+ return;
+
+ // Get the element type
+ auto elementType = structuredBufferType->getElementType();
+
+ // Check if the element type contains recursive references
+ if (containsRecursiveType(visitor, elementType))
+ {
+ visitor->getSink()->diagnose(
+ varDecl->loc,
+ Diagnostics::recursiveTypesFoundInStructuredBuffer,
+ elementType);
+ }
+}
+
void diagnoseMissingCapabilityProvenance(
CompilerOptionSet& optionSet,
DiagnosticSink* sink,
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index e6c66ddd3..d82bf4427 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -699,6 +699,8 @@ struct SharedSemanticsContext : public RefObject
GLSLBindingOffsetTracker m_glslBindingOffsetTracker;
+ Dictionary<Decl*, bool> m_typeContainsRecursionCache;
+
public:
SharedSemanticsContext(
Linkage* linkage,
@@ -919,19 +921,6 @@ private:
FacetList baseFacets,
FacetList::Builder& ioMergedFacets);
- struct TypePair
- {
- Type* type0;
- Type* type1;
- HashCode getHashCode() const
- {
- return combineHash(Slang::getHashCode(type0), Slang::getHashCode(type1));
- }
- bool operator==(const TypePair& other) const
- {
- return type0 == other.type0 && type1 == other.type1;
- }
- };
Dictionary<Type*, InheritanceInfo> m_mapTypeToInheritanceInfo;
Dictionary<DeclRef<Decl>, InheritanceInfo> m_mapDeclRefToInheritanceInfo;
Dictionary<TypePair, SubtypeWitness*> m_mapTypePairToSubtypeWitness;
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index a30b5f362..2febd317e 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -2254,6 +2254,18 @@ DIAGNOSTIC(
vectorWithInvalidElementCountEncountered,
"vector has invalid element count '$0', valid values are between '$1' and '$2' inclusive")
+DIAGNOSTIC(
+ 38204,
+ Error,
+ cannotUseResourceTypeInStructuredBuffer,
+ "StructuredBuffer element type '$0' cannot contain resource or opaque handle types")
+
+DIAGNOSTIC(
+ 38205,
+ Error,
+ recursiveTypesFoundInStructuredBuffer,
+ "structured buffer element type '$0' contains recursive type references")
+
// 39xxx - Type layout and parameter binding.
DIAGNOSTIC(
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 7d8f1438d..f5b818c5d 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -1304,6 +1304,9 @@ Result linkAndOptimizeIR(
#endif
validateIRModuleIfEnabled(codeGenContext, irModule);
+ if (!validateStructuredBufferResourceTypes(irModule, sink, targetRequest))
+ return SLANG_FAIL;
+
// Many of our target languages and/or downstream compilers
// don't support `struct` types that have resource-type fields.
// In order to work around this limitation, we will rewrite the
diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp
index 156fe249f..55f3ad227 100644
--- a/source/slang/slang-ir-validate.cpp
+++ b/source/slang/slang-ir-validate.cpp
@@ -1,6 +1,7 @@
// slang-ir-validate.cpp
#include "slang-ir-validate.h"
+#include "slang-compiler.h"
#include "slang-ir-dominators.h"
#include "slang-ir-insts.h"
#include "slang-ir-util.h"
@@ -25,6 +26,31 @@ struct IRValidateContext
HashSet<IRInst*> seenInsts;
};
+// Context class for structured buffer validation
+class StructuredBufferValidationContext
+{
+public:
+ StructuredBufferValidationContext(DiagnosticSink* sink, TargetRequest* targetRequest)
+ : m_sink(sink), m_targetRequest(targetRequest), m_hasErrors(false)
+ {
+ }
+
+ bool validate(IRModule* module);
+
+private:
+ DiagnosticSink* m_sink;
+ TargetRequest* m_targetRequest;
+ bool m_hasErrors;
+
+ // Cache of types we've already checked for containing opaque handles
+ HashSet<IRType*> m_checkedTypes;
+ HashSet<IRType*> m_typesWithOpaqueHandles;
+
+ bool containsOpaqueHandleTypeCached(IRType* type);
+ bool containsOpaqueHandleTypeInternal(IRType* type, HashSet<IRType*>& visitedInCurrentCheck);
+ void validateStructuredBufferVariable(IRInst* inst);
+};
+
void validateIRInst(IRValidateContext* context, IRInst* inst);
void validate(IRValidateContext* context, bool condition, IRInst* inst, char const* message)
@@ -624,4 +650,123 @@ void validateVectorsAndMatrices(
}
}
+//
+// Structure buffer resource types
+//
+
+bool StructuredBufferValidationContext::containsOpaqueHandleTypeCached(IRType* type)
+{
+ // Check cache first
+ if (m_checkedTypes.contains(type))
+ {
+ return m_typesWithOpaqueHandles.contains(type);
+ }
+
+ // Not in cache, need to check
+ HashSet<IRType*> visitedInCurrentCheck;
+ bool result = containsOpaqueHandleTypeInternal(type, visitedInCurrentCheck);
+
+ // Cache the result
+ m_checkedTypes.add(type);
+ if (result)
+ {
+ m_typesWithOpaqueHandles.add(type);
+ }
+
+ return result;
+}
+
+bool StructuredBufferValidationContext::containsOpaqueHandleTypeInternal(
+ IRType* type,
+ HashSet<IRType*>& visitedInCurrentCheck)
+{
+ // Prevent infinite recursion in current check
+ if (!visitedInCurrentCheck.add(type))
+ return false;
+
+ // Check if the type itself is an opaque handle
+ if (isResourceType(type))
+ return true;
+
+ // Check struct types
+ if (auto structType = as<IRStructType>(type))
+ {
+ for (auto field : structType->getFields())
+ {
+ if (containsOpaqueHandleTypeInternal(field->getFieldType(), visitedInCurrentCheck))
+ return true;
+ }
+ }
+ else if (auto arrayType = as<IRArrayTypeBase>(type))
+ {
+ return containsOpaqueHandleTypeInternal(arrayType->getElementType(), visitedInCurrentCheck);
+ }
+ else if (auto ptrType = as<IRPtrTypeBase>(type))
+ {
+ return containsOpaqueHandleTypeInternal(ptrType->getValueType(), visitedInCurrentCheck);
+ }
+
+ return false;
+}
+
+void StructuredBufferValidationContext::validateStructuredBufferVariable(IRInst* inst)
+{
+ IRType* type = inst->getDataType();
+
+ // Unwrap arrays if present
+ type = unwrapArrayAndPointers(type);
+
+ // Check if this is a structured buffer type
+ auto structuredBufferType = as<IRHLSLStructuredBufferTypeBase>(type);
+ if (!structuredBufferType)
+ return;
+
+ // Get the element type
+ auto elementType = structuredBufferType->getElementType();
+
+ // Check if the element type contains any resource/opaque handle types
+ if (containsOpaqueHandleTypeCached(elementType))
+ {
+ m_sink->diagnose(
+ inst->sourceLoc,
+ Diagnostics::cannotUseResourceTypeInStructuredBuffer,
+ elementType);
+ m_hasErrors = true;
+ }
+}
+
+bool StructuredBufferValidationContext::validate(IRModule* module)
+{
+ // Skip validation if bindless is enabled for this target
+ if (m_targetRequest && areResourceTypesBindlessOnTarget(m_targetRequest))
+ return true;
+
+ // Iterate through all global instructions
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ if (auto globalVar = as<IRGlobalParam>(globalInst))
+ {
+ validateStructuredBufferVariable(globalVar);
+ }
+ else if (auto func = as<IRFunc>(globalInst))
+ {
+ for (auto param : func->getParams())
+ {
+ validateStructuredBufferVariable(param);
+ }
+ }
+ }
+
+ return !m_hasErrors;
+}
+
+bool validateStructuredBufferResourceTypes(
+ IRModule* module,
+ DiagnosticSink* sink,
+ TargetRequest* targetRequest)
+{
+ StructuredBufferValidationContext context(sink, targetRequest);
+ return context.validate(module);
+}
+
} // namespace Slang
diff --git a/source/slang/slang-ir-validate.h b/source/slang/slang-ir-validate.h
index 7fc882f37..950e1b765 100644
--- a/source/slang/slang-ir-validate.h
+++ b/source/slang/slang-ir-validate.h
@@ -85,4 +85,9 @@ void validateVectorsAndMatrices(
DiagnosticSink* sink,
TargetRequest* targetRequest);
+bool validateStructuredBufferResourceTypes(
+ IRModule* module,
+ DiagnosticSink* sink,
+ TargetRequest* targetRequest);
+
} // namespace Slang
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index ab59112f3..c63ca2bec 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -8781,6 +8781,15 @@ IRType* unwrapArray(IRType* type)
return t;
}
+IRType* unwrapArrayAndPointers(IRType* type)
+{
+ if (const auto a = as<IRArrayTypeBase>(type))
+ return unwrapArrayAndPointers(a->getElementType());
+ if (const auto p = as<IRPtrTypeBase>(type))
+ return unwrapArrayAndPointers(p->getValueType());
+ return type;
+}
+
//
// IRTargetIntrinsicDecoration
//
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index d8fe51ddf..69a000b81 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -961,6 +961,7 @@ struct IRType : IRInst
};
IRType* unwrapArray(IRType* type);
+IRType* unwrapArrayAndPointers(IRType* type);
FIDDLE()
struct IRBasicType : IRType