// slang-ir-validate.cpp #include "slang-ir-validate.h" #include "slang-compiler.h" #include "slang-ir-dominators.h" #include "slang-ir-insts.h" #include "slang-ir-util.h" #include "slang-ir.h" namespace Slang { struct IRValidateContext { // The IR module we are validating. IRModule* module; RefPtr domTree; // A diagnostic sink to send errors to if anything is invalid. DiagnosticSink* sink; DiagnosticSink* getSink() { return sink; } // A set of instructions we've seen, to help confirm that // values are defined before they are used in a given block. HashSet seenInsts; }; // Context class for structured buffer validation class StructuredBufferValidationContext { public: StructuredBufferValidationContext(DiagnosticSink* sink, TargetRequest* targetRequest) : m_sink(sink), m_targetRequest(targetRequest), m_hasErrors(false) { } bool validate(IRModule* module); private: DiagnosticSink* m_sink; TargetRequest* m_targetRequest; bool m_hasErrors; // Cache of types we've already checked for containing opaque handles HashSet m_checkedTypes; HashSet m_typesWithOpaqueHandles; bool containsOpaqueHandleTypeCached(IRType* type); bool containsOpaqueHandleTypeInternal(IRType* type, HashSet& visitedInCurrentCheck); void validateStructuredBufferVariable(IRInst* inst); }; void validateIRInst(IRValidateContext* context, IRInst* inst); void validate(IRValidateContext* context, bool condition, IRInst* inst, char const* message) { if (!condition) { if (context) { context->getSink()->diagnose(inst, Diagnostics::irValidationFailed, message); } else { SLANG_ASSERT_FAILURE("IR validation failed"); } } } void validateIRInstChildren(IRValidateContext* context, IRInst* parent) { // We want to check that child instructions are correctly // ordered so that decorations come first, then any parameters, // and then any ordinary instructions. 
// // We will track what we have seen so far with a simple state // machine, which in valid IR should proceed monitonically // up through the following states: // enum State { kState_Initial = 0, kState_AfterDecoration, kState_AfterParam, kState_AfterOrdinary, }; State state = kState_Initial; IRInst* prevChild = nullptr; bool hasSeenTerminatorInst = false; for (auto child : parent->getDecorationsAndChildren()) { // We need to check the integrity of the parent/next/prev links of // all of our instructions validate(context, child->parent == parent, child, "parent link"); validate(context, child->prev == prevChild, child, "next/prev link"); // Recursively validate the instruction itself. validateIRInst(context, child); if (as(child)) { validate( context, state <= kState_AfterDecoration, child, "decorations must come before other child instructions"); state = kState_AfterDecoration; } else if (as(child)) { validate( context, state <= kState_AfterParam, child, "parameters must come before ordinary instructions"); state = kState_AfterParam; } else { state = kState_AfterOrdinary; } // Do some extra validation around terminator instructions: // // * The last instruction of a block should always be a terminator // * No other instruction should be a terminator // if (as(parent) && (child == parent->getLastDecorationOrChild())) { validate( context, as(child) != nullptr, child, "last instruction in block must be terminator"); } else { validate( context, !as(child), child, "terminator must be last instruction in a block"); } if (as(child)) { validate( context, !hasSeenTerminatorInst, child, "block must not contain more than one terminator"); hasSeenTerminatorInst = true; } prevChild = child; } } void validateIRInstOperand(IRValidateContext* context, IRInst* inst, IRUse* operandUse) { // The `IRUse` for the operand had better have `inst` as its user. 
validate(context, operandUse->getUser() == inst, inst, "operand user"); // The value we are using needs to fit into one of a few cases. // // * If the parent of `inst` and of `operand` is the same block, then // we require that `operand` is defined before `inst` // // * If the parents of `inst` and `operand` are both blocks in the // same functin, then the block defining `operand` must dominate // the block defining `inst`. // // * Otherwise, we simply require that the parent of `operand` be // an ancestor (transitive parent) of `inst`. auto instParent = inst->getParent(); auto operandValue = operandUse->get(); if (!operandValue) { // A null operand should almost always be an error, but // we currently have a few cases where this arises. // // TODO: plug the leaks. return; } auto operandParent = operandValue->getParent(); auto instParentBlock = getBlock(inst); if (instParentBlock) { if (auto operandParentBlock = as(operandParent)) { if (instParentBlock == operandParentBlock) { // If `operandValue` precedes `inst`, then we should // have already seen it, because we scan parent instructions // in order. if (context) { validate( context, context->seenInsts.contains(operandValue), inst, "def must come before use in same block"); } return; } auto instFunc = instParentBlock->getParent(); auto operandFunc = operandParentBlock->getParent(); if (instFunc == operandFunc) { // The two instructions are defined in different blocks of // the same function (or another value with code). We need // to validate that `operandParentBlock` dominates `instParentBlock`. // if (context && context->domTree) { validate( context, context->domTree->dominates(operandParentBlock, instParentBlock), inst, "def must dominate use"); } return; } } } // If the special cases above did not trigger, then either the two values // are nested in the same parent, but that parent isn't a block, or they // are nested in distinct parents, and those parents aren't both children // of a function. 
// // In either case, we need to enforce that the parent of `operand` needs // to be an ancestor of `inst`. // for (auto pp = instParent; pp; pp = pp->getParent()) { if (pp == operandParent) return; } // We allow out-of-order def-use in global scope. bool allInGlobalScope = inst->getParent() && inst->getParent()->getOp() == kIROp_ModuleInst; if (allInGlobalScope) { for (UInt i = 0; i < inst->getOperandCount(); i++) { auto op = inst->getOperand(i); if (!op) continue; if (!op->getParent()) continue; if (op->getParent()->getOp() != kIROp_ModuleInst) { allInGlobalScope = false; break; } } } if (allInGlobalScope) return; // Allow exceptions. switch (inst->getOp()) { case kIROp_DifferentiableTypeDictionaryItem: case kIROp_DebugScope: return; } // // We failed to find `operandParent` while walking the ancestors of `inst`, // so something had gone wrong. validate(context, false, inst, "def must be ancestor of use"); } void validateIRInstOperands(IRValidateContext* context, IRInst* inst) { if (inst->getFullType()) validateIRInstOperand(context, inst, &inst->typeUse); // Avoid validating decoration operands // since they don't have to conform to inst visibility // constraints. 
// if (as(inst)) return; UInt operandCount = inst->getOperandCount(); for (UInt ii = 0; ii < operandCount; ++ii) { validateIRInstOperand(context, inst, inst->getOperands() + ii); } } static thread_local bool _enableIRValidationAtInsert = false; // RAII class implementation for exception-safe IR validation state management IRValidationScope::IRValidationScope(bool enableValidation) : m_previousState(_enableIRValidationAtInsert) { _enableIRValidationAtInsert = enableValidation; } IRValidationScope::~IRValidationScope() { _enableIRValidationAtInsert = m_previousState; } void validateIRInstOperands(IRInst* inst) { if (!_enableIRValidationAtInsert) return; switch (inst->getOp()) { case kIROp_Loop: case kIROp_IfElse: case kIROp_UnconditionalBranch: case kIROp_ConditionalBranch: case kIROp_Switch: return; default: break; } validateIRInstOperands(nullptr, inst); } void validateCodeBody(IRValidateContext* context, IRGlobalValueWithCode* code) { HashSet blocks; for (auto block : code->getBlocks()) blocks.add(block); auto validateBranchTarget = [&](IRInst* inst, IRBlock* target) { validate( context, blocks.contains(target), inst, "branch inst must have a valid target block that is defined within the same " "scope."); }; for (auto block : code->getBlocks()) { auto terminator = block->getTerminator(); validate(context, terminator, block, "block must have valid terminator inst."); switch (terminator->getOp()) { case kIROp_ConditionalBranch: validateBranchTarget(terminator, as(terminator)->getTrueBlock()); validateBranchTarget(terminator, as(terminator)->getFalseBlock()); break; case kIROp_Loop: case kIROp_UnconditionalBranch: validateBranchTarget( terminator, as(terminator)->getTargetBlock()); break; case kIROp_Switch: { auto switchInst = as(terminator); for (UInt i = 0; i < switchInst->getCaseCount(); i++) { validateBranchTarget(switchInst, switchInst->getCaseLabel(i)); } validateBranchTarget(switchInst, switchInst->getDefaultLabel()); validateBranchTarget(switchInst, 
switchInst->getBreakLabel()); } } } } void validateIRInst(IRValidateContext* context, IRInst* inst) { // Validate that any operands of the instruction are used appropriately validateIRInstOperands(context, inst); context->seenInsts.add(inst); if (auto code = as(inst)) { context->domTree = computeDominatorTree(code); validateCodeBody(context, code); } // If `inst` is itself a parent instruction, then we need to recursively // validate its children. validateIRInstChildren(context, inst); if (as(inst)) context->domTree = nullptr; } void validateIRInst(IRInst* inst) { IRValidateContext contextStorage; IRValidateContext* context = &contextStorage; DiagnosticSink sink; context->module = inst->getModule(); context->sink = &sink; if (auto func = as(inst)) context->domTree = computeDominatorTree(func); validateIRInst(context, inst); } void validateIRModule(IRModule* module, DiagnosticSink* sink) { IRValidateContext contextStorage; IRValidateContext* context = &contextStorage; context->module = module; context->sink = sink; auto moduleInst = module->getModuleInst(); validate(context, moduleInst != nullptr, moduleInst, "module instruction"); validate(context, moduleInst->parent == nullptr, moduleInst, "module instruction parent"); validate(context, moduleInst->prev == nullptr, moduleInst, "module instruction prev"); validate(context, moduleInst->next == nullptr, moduleInst, "module instruction next"); validateIRInst(context, moduleInst); } void validateIRModuleIfEnabled(CompileRequestBase* compileRequest, IRModule* module) { if (!compileRequest->getLinkage()->m_optionSet.getBoolOption(CompilerOptionName::ValidateIr)) return; auto sink = compileRequest->getSink(); validateIRModule(module, sink); } void validateIRModuleIfEnabled(CodeGenContext* codeGenContext, IRModule* module) { if (!codeGenContext->shouldValidateIR()) return; auto sink = codeGenContext->getSink(); validateIRModule(module, sink); } // Returns whether 'dst' is a valid destination for atomic operations, meaning 
// it leads either to 'groupshared' or 'device buffer' memory.
//
// Accepted destinations, in the order they are checked:
// * a `dst` whose rate passes the first cast (the result is named `isGroupShared`);
// * a `dst` matching either of the next two casts;
// * a `dst` whose data type passes the `ptrType` cast and whose address space is
//   Global, GroupShared, StorageBuffer or UserPointer;
// * a `dst` matching the following cast, decided solely by its data type:
//   GLSL shader storage buffer or texture types are accepted, anything else is
//   rejected immediately;
// * a `param` whose type passes the `outType` cast and is in the GroupShared
//   address space, or any such param when `skipFuncParamValidation` is set;
// * an element/offset/field address computation whose base is itself valid.
//
// NOTE(review): the `as(...)` casts in this block appear to have lost their
// template arguments (`as<T>(...)`), so the exact instruction/type kinds each
// check matches cannot be read from this view; confirm against version control.
static bool isValidAtomicDest(bool skipFuncParamValidation, IRInst* dst)
{
    bool isGroupShared = as(dst->getRate());
    if (isGroupShared)
        return true;

    if (as(dst))
        return true;
    if (as(dst))
        return true;

    if (auto ptrType = as(dst->getDataType()))
    {
        switch (ptrType->getAddressSpace())
        {
        case AddressSpace::Global:
        case AddressSpace::GroupShared:
        case AddressSpace::StorageBuffer:
        case AddressSpace::UserPointer:
            return true;
        default:
            break;
        }
    }

    if (as(dst))
    {
        // Once this cast matches, the answer depends only on the data type;
        // later checks are never reached for such a `dst`.
        switch (dst->getDataType()->getOp())
        {
        case kIROp_GLSLShaderStorageBufferType:
        case kIROp_TextureType:
            return true;
        default:
            return false;
        }
    }

    if (auto param = as(dst))
    {
        auto paramType = param->getDataType();
        if (auto outType = as(paramType))
        {
            if (outType->getAddressSpace() == AddressSpace::GroupShared)
            {
                return true;
            }
            else if (skipFuncParamValidation)
            {
                // We haven't actually verified that this is a valid atomic operation destination,
                // but the caller has asked to skip this specific validation.
                return true;
            }
        }
    }

    // Address computations are valid destinations exactly when their base is.
    if (auto getElementPtr = as(dst))
        return isValidAtomicDest(skipFuncParamValidation, getElementPtr->getBase());
    if (auto getOffsetPtr = as(dst))
        return isValidAtomicDest(skipFuncParamValidation, getOffsetPtr->getBase());
    if (auto fieldAddress = as(dst))
        return isValidAtomicDest(skipFuncParamValidation, fieldAddress->getBase());

    return false;
}

// Walk `inst` and, recursively, all of its modifiable children. For every
// atomic instruction listed below, diagnose `invalidAtomicDestinationPointer`
// (at the instruction's source location) when its first operand is not
// accepted by isValidAtomicDest.
void validateAtomicOperations(bool skipFuncParamValidation, DiagnosticSink* sink, IRInst* inst)
{
    switch (inst->getOp())
    {
    case kIROp_AtomicLoad:
    case kIROp_AtomicStore:
    case kIROp_AtomicExchange:
    case kIROp_AtomicCompareExchange:
    case kIROp_AtomicAdd:
    case kIROp_AtomicSub:
    case kIROp_AtomicAnd:
    case kIROp_AtomicOr:
    case kIROp_AtomicXor:
    case kIROp_AtomicMin:
    case kIROp_AtomicMax:
    case kIROp_AtomicInc:
    case kIROp_AtomicDec:
        {
            // Operand 0 of every atomic op above is the destination pointer.
            IRInst* destinationPtr = inst->getOperand(0);
            if (!isValidAtomicDest(skipFuncParamValidation, destinationPtr))
                sink->diagnose(inst->sourceLoc, Diagnostics::invalidAtomicDestinationPointer);
        }
        break;

    default:
        break;
    }

    for (auto child : inst->getModifiableChildren())
    {
        validateAtomicOperations(skipFuncParamValidation, sink, child);
    }
}

// Diagnose `disallowedElementTypeEncountered` at `sourceLoc` when `elementType`
// is not an acceptable vector/matrix element type.
//
// Floating-point element types are always accepted. Integral element types are
// accepted only if their bit width is set in `allowedWidths`, a bitmask whose
// only permitted bits are 8, 16, 32 and 64 (asserted below); a mask of 0
// rejects every integral type. Any other type is accepted only if it passes the
// final cast (NOTE(review): its template argument appears stripped; presumably
// a bool type -- confirm).
static void validateVectorOrMatrixElementType(
    DiagnosticSink* sink,
    SourceLoc sourceLoc,
    IRType* elementType,
    uint32_t allowedWidths,
    const DiagnosticInfo& disallowedElementTypeEncountered)
{
    if (!isFloatingType(elementType))
    {
        if (isIntegralType(elementType))
        {
            IntInfo info = getIntTypeInfo(elementType);
            if (allowedWidths == 0U)
            {
                sink->diagnose(sourceLoc, disallowedElementTypeEncountered, elementType);
            }
            else
            {
                bool widthAllowed = false;
                // Only bits 3..6 (values 8, 16, 32, 64) may be set in the mask.
                SLANG_ASSERT((allowedWidths & ~(0xfU << 3)) == 0U);
                for (uint32_t p = 3U; p <= 6U; p++)
                {
                    uint32_t width = 1U << p;
                    if (!(allowedWidths & width))
                        continue;
                    widthAllowed = widthAllowed || (info.width == width);
                }
                if (!widthAllowed)
                {
                    sink->diagnose(sourceLoc, disallowedElementTypeEncountered, elementType);
                }
            }
        }
        else if (!as(elementType))
        {
            sink->diagnose(sourceLoc, disallowedElementTypeEncountered, elementType);
        }
    }
}
static void validateVectorElementCount(DiagnosticSink* sink, IRVectorType* vectorType) { const auto elementCount = as(vectorType->getElementCount())->getValue(); // 1-vectors are supported and are legalized/transformed properly when targetting unsupported // backends. const IRIntegerValue minCount = 1; const IRIntegerValue maxCount = 4; if ((elementCount < minCount) || (elementCount > maxCount)) { sink->diagnose( vectorType->sourceLoc, Diagnostics::vectorWithInvalidElementCountEncountered, elementCount, "1", maxCount); } } void validateVectorsAndMatrices( IRModule* module, DiagnosticSink* sink, TargetRequest* targetRequest) { for (auto globalInst : module->getGlobalInsts()) { if (auto matrixType = as(globalInst)) { // Matrices with row/col dimension 1 are only well-supported on D3D targets if (!isD3DTarget(targetRequest)) { // Verify that neither row nor col count is 1 auto colCount = as(matrixType->getColumnCount()); auto rowCount = as(matrixType->getRowCount()); if ((rowCount && (rowCount->getValue() == 1)) || (colCount && (colCount->getValue() == 1))) { sink->diagnose(matrixType->sourceLoc, Diagnostics::matrixColumnOrRowCountIsOne); } } // Matrix element type validation removed to allow integer/bool matrices // which will be lowered to arrays of vectors on targets that don't support them // natively } else if (auto vectorType = as(globalInst)) { // Verify that the element type is a floating point type, or an allowed integral type auto elementType = vectorType->getElementType(); uint32_t allowedWidths = 0U; if (isWGPUTarget(targetRequest)) allowedWidths = 32U; else allowedWidths = 8U | 16U | 32U | 64U; validateVectorOrMatrixElementType( sink, vectorType->sourceLoc, elementType, allowedWidths, Diagnostics::vectorWithDisallowedElementTypeEncountered); validateVectorElementCount(sink, vectorType); } } } // // Structure buffer resource types // bool StructuredBufferValidationContext::containsOpaqueHandleTypeCached(IRType* type) { // Check cache first if 
(m_checkedTypes.contains(type)) { return m_typesWithOpaqueHandles.contains(type); } // Not in cache, need to check HashSet visitedInCurrentCheck; bool result = containsOpaqueHandleTypeInternal(type, visitedInCurrentCheck); // Cache the result m_checkedTypes.add(type); if (result) { m_typesWithOpaqueHandles.add(type); } return result; } bool StructuredBufferValidationContext::containsOpaqueHandleTypeInternal( IRType* type, HashSet& visitedInCurrentCheck) { // Prevent infinite recursion in current check if (!visitedInCurrentCheck.add(type)) return false; // Check if the type itself is an opaque handle if (isResourceType(type)) return true; // Check struct types if (auto structType = as(type)) { for (auto field : structType->getFields()) { if (containsOpaqueHandleTypeInternal(field->getFieldType(), visitedInCurrentCheck)) return true; } } else if (auto arrayType = as(type)) { return containsOpaqueHandleTypeInternal(arrayType->getElementType(), visitedInCurrentCheck); } else if (auto ptrType = as(type)) { return containsOpaqueHandleTypeInternal(ptrType->getValueType(), visitedInCurrentCheck); } return false; } void StructuredBufferValidationContext::validateStructuredBufferVariable(IRInst* inst) { IRType* type = inst->getDataType(); // Unwrap arrays if present type = unwrapArrayAndPointers(type); // Check if this is a structured buffer type auto structuredBufferType = as(type); if (!structuredBufferType) return; // Get the element type auto elementType = structuredBufferType->getElementType(); // Check if the element type contains any resource/opaque handle types if (containsOpaqueHandleTypeCached(elementType)) { m_sink->diagnose( inst->sourceLoc, Diagnostics::cannotUseResourceTypeInStructuredBuffer, elementType); m_hasErrors = true; } } bool StructuredBufferValidationContext::validate(IRModule* module) { // Skip validation if bindless is enabled for this target if (m_targetRequest && areResourceTypesBindlessOnTarget(m_targetRequest)) return true; // Iterate through 
all global instructions for (auto globalInst : module->getGlobalInsts()) { if (auto globalVar = as(globalInst)) { validateStructuredBufferVariable(globalVar); } else if (auto func = as(globalInst)) { for (auto param : func->getParams()) { validateStructuredBufferVariable(param); } } } return !m_hasErrors; } bool validateStructuredBufferResourceTypes( IRModule* module, DiagnosticSink* sink, TargetRequest* targetRequest) { StructuredBufferValidationContext context(sink, targetRequest); return context.validate(module); } } // namespace Slang