diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ast-base.h | 14 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 105 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 15 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 12 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-validate.cpp | 145 | ||||
| -rw-r--r-- | source/slang/slang-ir-validate.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 1 |
9 files changed, 296 insertions, 13 deletions
diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h index dbbfe6b92..828bbdb5c 100644 --- a/source/slang/slang-ast-base.h +++ b/source/slang/slang-ast-base.h @@ -598,6 +598,20 @@ protected: ASTBuilder* m_astBuilderForReflection; }; +struct TypePair +{ + Type* type0; + Type* type1; + HashCode getHashCode() const + { + return combineHash(Slang::getHashCode(type0), Slang::getHashCode(type1)); + } + bool operator==(const TypePair& other) const + { + return type0 == other.type0 && type1 == other.type1; + } +}; + template<typename T> SLANG_FORCE_INLINE T* as(Type* obj) { diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index e59cf6ad5..4711eaddd 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -2698,6 +2698,8 @@ static Expr* constructDefaultInitExprForType(SemanticsVisitor* visitor, VarDeclB } } +void validateStructuredBufferElementType(SemanticsVisitor* visitor, VarDeclBase* varDecl); + void SemanticsDeclBodyVisitor::checkVarDeclCommon(VarDeclBase* varDecl) { DiagnoseIsAllowedInitExpr(varDecl, getSink()); @@ -2892,6 +2894,9 @@ void SemanticsDeclBodyVisitor::checkVarDeclCommon(VarDeclBase* varDecl) } } } + + validateStructuredBufferElementType(this, varDecl); + bool isGlobalOrLocalVar = !isGlobalShaderParameter(varDecl) && !as<ParamDecl>(varDecl) && (!parentDecl || isEffectivelyStatic(varDecl)); if (isGlobalOrLocalVar) @@ -15018,6 +15023,106 @@ bool isOpaqueHandleType(Type* type) return false; } +bool containsRecursiveTypeImpl(SemanticsVisitor* visitor, Type* type, HashSet<Decl*>& currentPath) +{ + // Skip modified types (const, etc.) + while (auto modifiedType = as<ModifiedType>(type)) + type = modifiedType->getBase(); + + // Check if this is a StructuredBuffer type and look inside it + if (auto structuredBufferType = as<HLSLStructuredBufferTypeBase>(type)) + { + return containsRecursiveTypeImpl( + visitor, + structuredBufferType->getElementType(), + currentPath); + } + + // Check if this is an array type and look inside it + if (auto arrayType = as<ArrayExpressionType>(type)) + { + return containsRecursiveTypeImpl(visitor, arrayType->getElementType(), currentPath); + } + + if (auto declRefType = as<DeclRefType>(type)) + { + auto typeDecl = declRefType->getDeclRef().getDecl(); + + // Check global cache first - if we've already fully analyzed this type, use that result + auto shared = visitor->getShared(); + if (auto cachedResult = shared->m_typeContainsRecursionCache.tryGetValue(typeDecl)) + { + return *cachedResult; + } + + // If we're currently exploring this type, we found a cycle! + if (currentPath.contains(typeDecl)) + { + return true; + } + + // Add to current exploration path + currentPath.add(typeDecl); + + bool hasRecursion = false; + + // Check members if it's an aggregate type + if (auto aggTypeDecl = declRefType->getDeclRef().as<AggTypeDecl>()) + { + for (auto member : aggTypeDecl.getDecl()->getMembersOfType<VarDeclBase>()) + { + if (isEffectivelyStatic(member)) + continue; + + if (containsRecursiveTypeImpl(visitor, member->getType(), currentPath)) + { + hasRecursion = true; + break; + } + } + } + + // Remove from current exploration path + currentPath.remove(typeDecl); + + // Cache the result globally + shared->m_typeContainsRecursionCache[typeDecl] = hasRecursion; + + return hasRecursion; + } + + return false; +} + +bool containsRecursiveType(SemanticsVisitor* visitor, Type* type) +{ + HashSet<Decl*> currentPath; + return containsRecursiveTypeImpl(visitor, type, currentPath); +} + +void validateStructuredBufferElementType(SemanticsVisitor* visitor, VarDeclBase* varDecl) +{ + auto type = unwrapArrayType(varDecl->getType()); + + // Check if this is a StructuredBuffer type + auto structuredBufferType = as<HLSLStructuredBufferTypeBase>(type); + + if (!structuredBufferType) + return; + + // Get the element type + auto elementType = structuredBufferType->getElementType(); + + // Check if the element type contains recursive references + if (containsRecursiveType(visitor, elementType)) + { + visitor->getSink()->diagnose( + varDecl->loc, + Diagnostics::recursiveTypesFoundInStructuredBuffer, + elementType); + } +} + void diagnoseMissingCapabilityProvenance( CompilerOptionSet& optionSet, DiagnosticSink* sink, diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index e6c66ddd3..d82bf4427 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -699,6 +699,8 @@ struct SharedSemanticsContext : public RefObject GLSLBindingOffsetTracker m_glslBindingOffsetTracker; + Dictionary<Decl*, bool> m_typeContainsRecursionCache; + public: SharedSemanticsContext( Linkage* linkage, @@ -919,19 +921,6 @@ private: FacetList baseFacets, FacetList::Builder& ioMergedFacets); - struct TypePair - { - Type* type0; - Type* type1; - HashCode getHashCode() const - { - return combineHash(Slang::getHashCode(type0), Slang::getHashCode(type1)); - } - bool operator==(const TypePair& other) const - { - return type0 == other.type0 && type1 == other.type1; - } - }; Dictionary<Type*, InheritanceInfo> m_mapTypeToInheritanceInfo; Dictionary<DeclRef<Decl>, InheritanceInfo> m_mapDeclRefToInheritanceInfo; Dictionary<TypePair, SubtypeWitness*> m_mapTypePairToSubtypeWitness; diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index a30b5f362..2febd317e 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -2254,6 +2254,18 @@ DIAGNOSTIC( vectorWithInvalidElementCountEncountered, "vector has invalid element count '$0', valid values are between '$1' and '$2' inclusive") +DIAGNOSTIC( + 38204, + Error, + cannotUseResourceTypeInStructuredBuffer, + "StructuredBuffer element type '$0' cannot contain resource or opaque handle types") + +DIAGNOSTIC( + 38205, + Error, + recursiveTypesFoundInStructuredBuffer, + "structured buffer element type '$0' contains recursive type references") + // 39xxx - Type layout and parameter binding. DIAGNOSTIC( diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 7d8f1438d..f5b818c5d 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -1304,6 +1304,9 @@ Result linkAndOptimizeIR( #endif validateIRModuleIfEnabled(codeGenContext, irModule); + if (!validateStructuredBufferResourceTypes(irModule, sink, targetRequest)) + return SLANG_FAIL; + // Many of our target languages and/or downstream compilers // don't support `struct` types that have resource-type fields. // In order to work around this limitation, we will rewrite the diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp index 156fe249f..55f3ad227 100644 --- a/source/slang/slang-ir-validate.cpp +++ b/source/slang/slang-ir-validate.cpp @@ -1,6 +1,7 @@ // slang-ir-validate.cpp #include "slang-ir-validate.h" +#include "slang-compiler.h" #include "slang-ir-dominators.h" #include "slang-ir-insts.h" #include "slang-ir-util.h" @@ -25,6 +26,31 @@ struct IRValidateContext HashSet<IRInst*> seenInsts; }; +// Context class for structured buffer validation +class StructuredBufferValidationContext +{ +public: + StructuredBufferValidationContext(DiagnosticSink* sink, TargetRequest* targetRequest) + : m_sink(sink), m_targetRequest(targetRequest), m_hasErrors(false) + { + } + + bool validate(IRModule* module); + +private: + DiagnosticSink* m_sink; + TargetRequest* m_targetRequest; + bool m_hasErrors; + + // Cache of types we've already checked for containing opaque handles + HashSet<IRType*> m_checkedTypes; + HashSet<IRType*> m_typesWithOpaqueHandles; + + bool containsOpaqueHandleTypeCached(IRType* type); + bool containsOpaqueHandleTypeInternal(IRType* type, HashSet<IRType*>& visitedInCurrentCheck); + void validateStructuredBufferVariable(IRInst* inst); +}; + void validateIRInst(IRValidateContext* context, IRInst* inst); void validate(IRValidateContext* context, bool condition, IRInst* inst, char const* message) @@ -624,4 +650,123 @@ void validateVectorsAndMatrices( } } +// +// Structure buffer resource types +// + +bool StructuredBufferValidationContext::containsOpaqueHandleTypeCached(IRType* type) +{ + // Check cache first + if (m_checkedTypes.contains(type)) + { + return m_typesWithOpaqueHandles.contains(type); + } + + // Not in cache, need to check + HashSet<IRType*> visitedInCurrentCheck; + bool result = containsOpaqueHandleTypeInternal(type, visitedInCurrentCheck); + + // Cache the result + m_checkedTypes.add(type); + if (result) + { + m_typesWithOpaqueHandles.add(type); + } + + return result; +} + +bool StructuredBufferValidationContext::containsOpaqueHandleTypeInternal( + IRType* type, + HashSet<IRType*>& visitedInCurrentCheck) +{ + // Prevent infinite recursion in current check + if (!visitedInCurrentCheck.add(type)) + return false; + + // Check if the type itself is an opaque handle + if (isResourceType(type)) + return true; + + // Check struct types + if (auto structType = as<IRStructType>(type)) + { + for (auto field : structType->getFields()) + { + if (containsOpaqueHandleTypeInternal(field->getFieldType(), visitedInCurrentCheck)) + return true; + } + } + else if (auto arrayType = as<IRArrayTypeBase>(type)) + { + return containsOpaqueHandleTypeInternal(arrayType->getElementType(), visitedInCurrentCheck); + } + else if (auto ptrType = as<IRPtrTypeBase>(type)) + { + return containsOpaqueHandleTypeInternal(ptrType->getValueType(), visitedInCurrentCheck); + } + + return false; +} + +void StructuredBufferValidationContext::validateStructuredBufferVariable(IRInst* inst) +{ + IRType* type = inst->getDataType(); + + // Unwrap arrays if present + type = unwrapArrayAndPointers(type); + + // Check if this is a structured buffer type + auto structuredBufferType = as<IRHLSLStructuredBufferTypeBase>(type); + if (!structuredBufferType) + return; + + // Get the element type + auto elementType = structuredBufferType->getElementType(); + + // Check if the element type contains any resource/opaque handle types + if (containsOpaqueHandleTypeCached(elementType)) + { + m_sink->diagnose( + inst->sourceLoc, + Diagnostics::cannotUseResourceTypeInStructuredBuffer, + elementType); + m_hasErrors = true; + } +} + +bool StructuredBufferValidationContext::validate(IRModule* module) +{ + // Skip validation if bindless is enabled for this target + if (m_targetRequest && areResourceTypesBindlessOnTarget(m_targetRequest)) + return true; + + // Iterate through all global instructions + for (auto globalInst : module->getGlobalInsts()) + { + if (auto globalVar = as<IRGlobalParam>(globalInst)) + { + validateStructuredBufferVariable(globalVar); + } + else if (auto func = as<IRFunc>(globalInst)) + { + for (auto param : func->getParams()) + { + validateStructuredBufferVariable(param); + } + } + } + + return !m_hasErrors; +} + +bool validateStructuredBufferResourceTypes( + IRModule* module, + DiagnosticSink* sink, + TargetRequest* targetRequest) +{ + StructuredBufferValidationContext context(sink, targetRequest); + return context.validate(module); +} + } // namespace Slang diff --git a/source/slang/slang-ir-validate.h b/source/slang/slang-ir-validate.h index 7fc882f37..950e1b765 100644 --- a/source/slang/slang-ir-validate.h +++ b/source/slang/slang-ir-validate.h @@ -85,4 +85,9 @@ void validateVectorsAndMatrices( DiagnosticSink* sink, TargetRequest* targetRequest); +bool validateStructuredBufferResourceTypes( + IRModule* module, + DiagnosticSink* sink, + TargetRequest* targetRequest); + } // namespace Slang diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index ab59112f3..c63ca2bec 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -8781,6 +8781,15 @@ IRType* unwrapArray(IRType* type) return t; } +IRType* unwrapArrayAndPointers(IRType* type) +{ + if (const auto a = as<IRArrayTypeBase>(type)) + return unwrapArrayAndPointers(a->getElementType()); + if (const auto p = as<IRPtrTypeBase>(type)) + return unwrapArrayAndPointers(p->getValueType()); + return type; +} + // // IRTargetIntrinsicDecoration // diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index d8fe51ddf..69a000b81 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -961,6 +961,7 @@ struct IRType : IRInst }; IRType* unwrapArray(IRType* type); +IRType* unwrapArrayAndPointers(IRType* type); FIDDLE() struct IRBasicType : IRType |
