diff options
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/slang-emit.cpp | 13 | ||||
| -rw-r--r-- | source/slang/slang-ir-variable-scope-correction.cpp | 259 | ||||
| -rw-r--r-- | source/slang/slang-ir-variable-scope-correction.h | 35 |
3 files changed, 307 insertions, 0 deletions
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index f724b1941..fbc99b0ce 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -78,6 +78,7 @@ #include "slang-ir-pytorch-cpp-binding.h" #include "slang-ir-uniformity.h" #include "slang-ir-vk-invert-y.h" +#include "slang-ir-variable-scope-correction.h" #include "slang-legalize-types.h" #include "slang-lower-to-ir.h" #include "slang-mangle.h" @@ -1062,6 +1063,18 @@ Result linkAndOptimizeIR( #endif validateIRModuleIfEnabled(codeGenContext, irModule); + if ( (target != CodeGenTarget::SPIRV) && (target != CodeGenTarget::SPIRVAssembly) ) + { + // We need to perform a final pass to ensure that all the + // variables in the IR module have their scopes set correctly. + // + // This is a separate pass because it needs to run after + // all the other optimization passes have been performed. + + applyVariableScopeCorrection(irModule, targetRequest); + validateIRModuleIfEnabled(codeGenContext, irModule); + } + auto metadata = new ArtifactPostEmitMetadata; outLinkedIR.metadata = metadata; diff --git a/source/slang/slang-ir-variable-scope-correction.cpp b/source/slang/slang-ir-variable-scope-correction.cpp new file mode 100644 index 000000000..e8e8f70a1 --- /dev/null +++ b/source/slang/slang-ir-variable-scope-correction.cpp @@ -0,0 +1,259 @@ +#include "slang-ir-insts.h" +#include "slang-ir.h" +#include "slang-ir-clone.h" + +#include "slang-ir-dominators.h" +#include "slang-ir-variable-scope-correction.h" +#include "slang-ir-util.h" + +namespace Slang +{ + +bool isCPUTarget(TargetRequest* targetReq); +bool isCUDATarget(TargetRequest* targetReq); + +namespace { // anonymous +struct VariableScopeCorrectionContext +{ + VariableScopeCorrectionContext(IRModule* module, TargetRequest* targetReq): + m_module(module), m_builder(module), m_targetReq(targetReq) + { + } + + void processModule(); + + /// Process a function in the module + void _processFunction(IRFunc* funcInst); + void _processInstruction(IRDominatorTree* dominatorTree, IRInst* instAfterParam, + IRInst* originInst, const List<IRLoop*>& loopHeaderList, List<IRInst*>& workList); + void _processStorableInst(IRInst* insertLoc, IRInst* inst, const List<IRUse*>& outOfScopeUses); + void _processUnstorableInst(IRInst* inst, const List<IRUse*>& outOfScopeUser); + + bool _isStorableType(IRType* inst); + bool _isOutOfScopeUse(IRInst* inst, IRDominatorTree* domTree, const List<IRLoop*>& loopHeaderList); + + IRModule* m_module; + IRBuilder m_builder; + TargetRequest* m_targetReq; +}; + +void VariableScopeCorrectionContext::processModule() +{ + IRModuleInst* moduleInst = m_module->getModuleInst(); + for (IRInst* child : moduleInst->getChildren()) + { + // We want to find all of the functions, and process them + if (auto funcInst = as<IRFunc>(child)) + { + if (funcInst->getFirstBlock()) + { + _processFunction(funcInst); + } + } + } +} + +void VariableScopeCorrectionContext::_processFunction(IRFunc* funcInst) +{ + IRDominatorTree* dominatorTree = m_module->findOrCreateDominatorTree(funcInst); + List<IRInst*> workList; + Dictionary<IRBlock*, List<IRLoop*>> loopHeaderMap; + + // traverse all blocks in the function + for (auto block : funcInst->getBlocks()) + { + // Traverse all the dominators of a given block to check whether this given block is in a loop region. + // Loop region blocks are the blocks that are dominated by the loop header block + // but not dominated by the loop break block. + auto dominatorBlock = dominatorTree->getImmediateDominator(block); + List<IRLoop*> loopHeaderList; + for (; dominatorBlock; dominatorBlock = dominatorTree->getImmediateDominator(dominatorBlock)) + { + // Find if the block is loop header block + if (auto loopHeader = as<IRLoop>(dominatorBlock->getTerminator())) + { + // Get the break block of the loop and check if such block + auto breakBlock = loopHeader->getBreakBlock(); + + // Check if the current block is dominated by the break block. If so, it means that the block is in the loop region. + if (!dominatorTree->dominates(breakBlock, block)) + { + loopHeaderList.add(loopHeader); + } + } + } + loopHeaderMap.add(block, loopHeaderList); + } + + if (loopHeaderMap.getCount() == 0) + { + return; + } + + // Traverse all the instructions in function. + for (auto block : funcInst->getBlocks()) + { + if(loopHeaderMap.containsKey(block)) + { + for (auto inst : block->getChildren()) + { + List<IRInst*> instList; + // Don't process the variable declaration instruction because the code is not emitted for them unless there is a use. + if (inst->getOp() == kIROp_Var) + { + continue; + } + workList.add(inst); + } + } + } + + auto instAfterParam = funcInst->getFirstBlock()->getFirstOrdinaryInst(); + + for(auto inst = workList.begin(); inst != workList.end(); inst++) + { + if (auto loopHeaderList = loopHeaderMap.tryGetValue(getBlock(*inst))) + { + _processInstruction(dominatorTree, instAfterParam, *inst, *loopHeaderList, workList); + } + } +} + +// Check if the instruction is used outside of the loop. +// The loopHeaderList contains all the loop headers where the original instruction is defined. +// So we if the block of the user instruction is dominated by the break block of the loop header, +// it means that it was out of the loop, so it's out of the scope of the loop. +// Note the reason we use the loopHeaderList is because there could be nested loops, so we need to +// check all the loop headers from inner to outer. +bool VariableScopeCorrectionContext::_isOutOfScopeUse(IRInst * userInst, IRDominatorTree* domTree, const List<IRLoop*>& loopHeaderList) +{ + if (auto block = getBlock(userInst)) + { + // If the use site of this instruction is dominated by the break block, it means that the + // instruction is used after the break block, so we need to make that instruction available globally. + // By doing so, we record all the users of this instructions. + for(auto loopHeader : loopHeaderList) + { + auto breakBlock = loopHeader->getBreakBlock(); + if (domTree->dominates(breakBlock, block)) + { + return true; + } + } + } + return false; +} + +void VariableScopeCorrectionContext::_processInstruction(IRDominatorTree* dominatorTree, IRInst* instAfterParam, + IRInst* originInst, const List<IRLoop*>& loopHeaderList, List<IRInst*>& workList) +{ + List<IRUse*> outOfScopeUses; + for (auto use = originInst->firstUse; use; use=use->nextUse) + { + if(_isOutOfScopeUse(use->getUser(), dominatorTree, loopHeaderList)) + { + outOfScopeUses.add(use); + } + } + + if (outOfScopeUses.getCount() == 0) + return; + + if(_isStorableType(originInst->getDataType())) + { + _processStorableInst(instAfterParam, originInst, outOfScopeUses); + } + else + { + _processUnstorableInst(originInst, outOfScopeUses); + // After processing the user, we need to add operands of the instruction to the worklist + // for later processing. + for(UInt idx = 0; idx < originInst->getOperandCount(); idx++) + { + workList.add(originInst->getOperand(idx)); + } + } +} + +void VariableScopeCorrectionContext::_processStorableInst(IRInst* insertLoc, IRInst* inst, const List<IRUse*>& outOfScopeUses) +{ + auto type = inst->getDataType(); + // store instruction must have a result type + SLANG_ASSERT(type); + + // declare a new variable at the beginning of the function used to store the result of the instruction + m_builder.setInsertBefore(insertLoc); + auto dstPtr = m_builder.emitVar(type); + + // insert a store instruction after the instruction + m_builder.setInsertAfter(inst); + m_builder.emitStore(dstPtr, inst); + + // last, replace operands in the use site instruction with the new variable + // Note, because "dstPtr" is a pointer type, we have to insert a load(dstPtr) instruction before use it. + // Simply replace any operand with pointer could generate error code. + for (auto use : outOfScopeUses) + { + m_builder.setInsertBefore(use->getUser()); + auto loadInst = m_builder.emitLoad(type, dstPtr); + m_builder.replaceOperand(use, loadInst); + } +} + +void VariableScopeCorrectionContext::_processUnstorableInst(IRInst* inst, const List<IRUse*>& outOfScopeUsers) +{ + IRCloneEnv cloneEnv; + auto clonedInst = cloneInst(&cloneEnv, &m_builder, inst); + + for (auto user : outOfScopeUsers) + { + // duplicate the invisible instruction and insert it right before the use site, + // then replace the operand with the duplicated instruction + clonedInst->insertBefore(user->getUser()); + m_builder.replaceOperand(user, clonedInst); + } +} + +bool VariableScopeCorrectionContext::_isStorableType(IRType* type) +{ + if (!type) + return false; + + // C/CPP/CUDA can store any type. + if (isCPUTarget(m_targetReq) || isCUDATarget(m_targetReq)) + return true; + + if (as<IRBasicType>(type)) + return true; + + switch(type->getOp()) + { + case kIROp_VectorType: + case kIROp_MatrixType: + case kIROp_StructType: + return true; + case kIROp_ArrayType: + { + if (auto arrayType = as<IRArrayTypeBase>(type)) + return _isStorableType(arrayType->getElementType()); + else + return false; + } + case kIROp_UnsizedArrayType: + return false; + default: + return false; + } +} + +} // anonymous + +void applyVariableScopeCorrection(IRModule* module, TargetRequest* targetReq) +{ + VariableScopeCorrectionContext context(module, targetReq); + + context.processModule(); +} + +} // namespace Slang + diff --git a/source/slang/slang-ir-variable-scope-correction.h b/source/slang/slang-ir-variable-scope-correction.h new file mode 100644 index 000000000..5f958f9d0 --- /dev/null +++ b/source/slang/slang-ir-variable-scope-correction.h @@ -0,0 +1,35 @@ +// slang-ir-variable-scope-correction.h +#ifndef SLANG_IR_VARIABLE_SCOPE_CORRECTION_H +#define SLANG_IR_VARIABLE_SCOPE_CORRECTION_H + +namespace Slang +{ + +struct IRModule; + +/// This pass correct the scope of variables in loop regions +/// +/// In the IR optimization pass, we turn all the loop to do-while loop form. +/// But in the do-while loop form, the loop body block is dominating the +/// blocks after the loop break block. E.g. +/// +/// do { +/// A +/// } while (cond); +/// B +/// +/// In the above example, the block A is dominating block B. This assumption +/// is fine for SPIRV and IR code, however, it's incorrect for all the other +/// language targets (e.g. c/c++/cuda/glsl/hlsl) because the instructions defined +/// in the block A are not visible from block B. Therefore, when translating to +/// other textual language, there could be issue for the variables scope. +/// +/// To fix this issue, we first detect the instructions that are defined +/// inside the loop block (block A), then check if these instructions are used after +/// the break block (block B). If so, we duplicate these instructions right before +/// their users such that we can make those instructions available globally. +void applyVariableScopeCorrection(IRModule* module, TargetRequest* targetReq); + +} + +#endif // SLANG_IR_VARIABLE_SCOPE_CORRECTION_H |
