summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-emit.cpp13
-rw-r--r--source/slang/slang-ir-variable-scope-correction.cpp259
-rw-r--r--source/slang/slang-ir-variable-scope-correction.h35
3 files changed, 307 insertions, 0 deletions
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index f724b1941..fbc99b0ce 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -78,6 +78,7 @@
#include "slang-ir-pytorch-cpp-binding.h"
#include "slang-ir-uniformity.h"
#include "slang-ir-vk-invert-y.h"
+#include "slang-ir-variable-scope-correction.h"
#include "slang-legalize-types.h"
#include "slang-lower-to-ir.h"
#include "slang-mangle.h"
@@ -1062,6 +1063,18 @@ Result linkAndOptimizeIR(
#endif
validateIRModuleIfEnabled(codeGenContext, irModule);
+ if ( (target != CodeGenTarget::SPIRV) && (target != CodeGenTarget::SPIRVAssembly) )
+ {
+ // We need to perform a final pass to ensure that all the
+ // variables in the IR module have their scopes set correctly.
+ //
+ // This is a separate pass because it needs to run after
+ // all the other optimization passes have been performed.
+
+ applyVariableScopeCorrection(irModule, targetRequest);
+ validateIRModuleIfEnabled(codeGenContext, irModule);
+ }
+
auto metadata = new ArtifactPostEmitMetadata;
outLinkedIR.metadata = metadata;
diff --git a/source/slang/slang-ir-variable-scope-correction.cpp b/source/slang/slang-ir-variable-scope-correction.cpp
new file mode 100644
index 000000000..e8e8f70a1
--- /dev/null
+++ b/source/slang/slang-ir-variable-scope-correction.cpp
@@ -0,0 +1,259 @@
+#include "slang-ir-insts.h"
+#include "slang-ir.h"
+#include "slang-ir-clone.h"
+
+#include "slang-ir-dominators.h"
+#include "slang-ir-variable-scope-correction.h"
+#include "slang-ir-util.h"
+
+namespace Slang
+{
+
+bool isCPUTarget(TargetRequest* targetReq);
+bool isCUDATarget(TargetRequest* targetReq);
+
+namespace { // anonymous
+struct VariableScopeCorrectionContext
+{
+ VariableScopeCorrectionContext(IRModule* module, TargetRequest* targetReq):
+ m_module(module), m_builder(module), m_targetReq(targetReq)
+ {
+ }
+
+ void processModule();
+
+ /// Process a function in the module
+ void _processFunction(IRFunc* funcInst);
+ void _processInstruction(IRDominatorTree* dominatorTree, IRInst* instAfterParam,
+ IRInst* originInst, const List<IRLoop*>& loopHeaderList, List<IRInst*>& workList);
+ void _processStorableInst(IRInst* insertLoc, IRInst* inst, const List<IRUse*>& outOfScopeUses);
+ void _processUnstorableInst(IRInst* inst, const List<IRUse*>& outOfScopeUser);
+
+ bool _isStorableType(IRType* inst);
+ bool _isOutOfScopeUse(IRInst* inst, IRDominatorTree* domTree, const List<IRLoop*>& loopHeaderList);
+
+ IRModule* m_module;
+ IRBuilder m_builder;
+ TargetRequest* m_targetReq;
+};
+
+void VariableScopeCorrectionContext::processModule()
+{
+ IRModuleInst* moduleInst = m_module->getModuleInst();
+ for (IRInst* child : moduleInst->getChildren())
+ {
+ // We want to find all of the functions, and process them
+ if (auto funcInst = as<IRFunc>(child))
+ {
+ if (funcInst->getFirstBlock())
+ {
+ _processFunction(funcInst);
+ }
+ }
+ }
+}
+
+void VariableScopeCorrectionContext::_processFunction(IRFunc* funcInst)
+{
+ IRDominatorTree* dominatorTree = m_module->findOrCreateDominatorTree(funcInst);
+ List<IRInst*> workList;
+ Dictionary<IRBlock*, List<IRLoop*>> loopHeaderMap;
+
+ // traverse all blocks in the function
+ for (auto block : funcInst->getBlocks())
+ {
+ // Traverse all the dominators of a given block to check whether this given block is in a loop region.
+ // Loop region blocks are the blocks that are dominated by the loop header block
+ // but not dominated by the loop break block.
+ auto dominatorBlock = dominatorTree->getImmediateDominator(block);
+ List<IRLoop*> loopHeaderList;
+ for (; dominatorBlock; dominatorBlock = dominatorTree->getImmediateDominator(dominatorBlock))
+ {
+ // Find if the block is loop header block
+ if (auto loopHeader = as<IRLoop>(dominatorBlock->getTerminator()))
+ {
+ // Get the break block of the loop and check if such block
+ auto breakBlock = loopHeader->getBreakBlock();
+
+ // Check if the current block is dominated by the break block. If so, it means that the block is in the loop region.
+ if (!dominatorTree->dominates(breakBlock, block))
+ {
+ loopHeaderList.add(loopHeader);
+ }
+ }
+ }
+ loopHeaderMap.add(block, loopHeaderList);
+ }
+
+ if (loopHeaderMap.getCount() == 0)
+ {
+ return;
+ }
+
+ // Traverse all the instructions in function.
+ for (auto block : funcInst->getBlocks())
+ {
+ if(loopHeaderMap.containsKey(block))
+ {
+ for (auto inst : block->getChildren())
+ {
+ List<IRInst*> instList;
+ // Don't process the variable declaration instruction because the code is not emitted for them unless there is a use.
+ if (inst->getOp() == kIROp_Var)
+ {
+ continue;
+ }
+ workList.add(inst);
+ }
+ }
+ }
+
+ auto instAfterParam = funcInst->getFirstBlock()->getFirstOrdinaryInst();
+
+ for(auto inst = workList.begin(); inst != workList.end(); inst++)
+ {
+ if (auto loopHeaderList = loopHeaderMap.tryGetValue(getBlock(*inst)))
+ {
+ _processInstruction(dominatorTree, instAfterParam, *inst, *loopHeaderList, workList);
+ }
+ }
+}
+
+// Check if the instruction is used outside of the loop.
+// The loopHeaderList contains all the loop headers where the original instruction is defined.
+// So we if the block of the user instruction is dominated by the break block of the loop header,
+// it means that it was out of the loop, so it's out of the scope of the loop.
+// Note the reason we use the loopHeaderList is because there could be nested loops, so we need to
+// check all the loop headers from inner to outer.
+bool VariableScopeCorrectionContext::_isOutOfScopeUse(IRInst * userInst, IRDominatorTree* domTree, const List<IRLoop*>& loopHeaderList)
+{
+ if (auto block = getBlock(userInst))
+ {
+ // If the use site of this instruction is dominated by the break block, it means that the
+ // instruction is used after the break block, so we need to make that instruction available globally.
+ // By doing so, we record all the users of this instructions.
+ for(auto loopHeader : loopHeaderList)
+ {
+ auto breakBlock = loopHeader->getBreakBlock();
+ if (domTree->dominates(breakBlock, block))
+ {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+void VariableScopeCorrectionContext::_processInstruction(IRDominatorTree* dominatorTree, IRInst* instAfterParam,
+ IRInst* originInst, const List<IRLoop*>& loopHeaderList, List<IRInst*>& workList)
+{
+ List<IRUse*> outOfScopeUses;
+ for (auto use = originInst->firstUse; use; use=use->nextUse)
+ {
+ if(_isOutOfScopeUse(use->getUser(), dominatorTree, loopHeaderList))
+ {
+ outOfScopeUses.add(use);
+ }
+ }
+
+ if (outOfScopeUses.getCount() == 0)
+ return;
+
+ if(_isStorableType(originInst->getDataType()))
+ {
+ _processStorableInst(instAfterParam, originInst, outOfScopeUses);
+ }
+ else
+ {
+ _processUnstorableInst(originInst, outOfScopeUses);
+ // After processing the user, we need to add operands of the instruction to the worklist
+ // for later processing.
+ for(UInt idx = 0; idx < originInst->getOperandCount(); idx++)
+ {
+ workList.add(originInst->getOperand(idx));
+ }
+ }
+}
+
+void VariableScopeCorrectionContext::_processStorableInst(IRInst* insertLoc, IRInst* inst, const List<IRUse*>& outOfScopeUses)
+{
+ auto type = inst->getDataType();
+ // store instruction must have a result type
+ SLANG_ASSERT(type);
+
+ // declare a new variable at the beginning of the function used to store the result of the instruction
+ m_builder.setInsertBefore(insertLoc);
+ auto dstPtr = m_builder.emitVar(type);
+
+ // insert a store instruction after the instruction
+ m_builder.setInsertAfter(inst);
+ m_builder.emitStore(dstPtr, inst);
+
+ // last, replace operands in the use site instruction with the new variable
+ // Note, because "dstPtr" is a pointer type, we have to insert a load(dstPtr) instruction before use it.
+ // Simply replace any operand with pointer could generate error code.
+ for (auto use : outOfScopeUses)
+ {
+ m_builder.setInsertBefore(use->getUser());
+ auto loadInst = m_builder.emitLoad(type, dstPtr);
+ m_builder.replaceOperand(use, loadInst);
+ }
+}
+
+void VariableScopeCorrectionContext::_processUnstorableInst(IRInst* inst, const List<IRUse*>& outOfScopeUsers)
+{
+ IRCloneEnv cloneEnv;
+ auto clonedInst = cloneInst(&cloneEnv, &m_builder, inst);
+
+ for (auto user : outOfScopeUsers)
+ {
+ // duplicate the invisible instruction and insert it right before the use site,
+ // then replace the operand with the duplicated instruction
+ clonedInst->insertBefore(user->getUser());
+ m_builder.replaceOperand(user, clonedInst);
+ }
+}
+
+bool VariableScopeCorrectionContext::_isStorableType(IRType* type)
+{
+ if (!type)
+ return false;
+
+ // C/CPP/CUDA can store any type.
+ if (isCPUTarget(m_targetReq) || isCUDATarget(m_targetReq))
+ return true;
+
+ if (as<IRBasicType>(type))
+ return true;
+
+ switch(type->getOp())
+ {
+ case kIROp_VectorType:
+ case kIROp_MatrixType:
+ case kIROp_StructType:
+ return true;
+ case kIROp_ArrayType:
+ {
+ if (auto arrayType = as<IRArrayTypeBase>(type))
+ return _isStorableType(arrayType->getElementType());
+ else
+ return false;
+ }
+ case kIROp_UnsizedArrayType:
+ return false;
+ default:
+ return false;
+ }
+}
+
+} // anonymous
+
+void applyVariableScopeCorrection(IRModule* module, TargetRequest* targetReq)
+{
+ VariableScopeCorrectionContext context(module, targetReq);
+
+ context.processModule();
+}
+
+} // namespace Slang
+
diff --git a/source/slang/slang-ir-variable-scope-correction.h b/source/slang/slang-ir-variable-scope-correction.h
new file mode 100644
index 000000000..5f958f9d0
--- /dev/null
+++ b/source/slang/slang-ir-variable-scope-correction.h
@@ -0,0 +1,35 @@
+// slang-ir-variable-scope-correction.h
+#ifndef SLANG_IR_VARIABLE_SCOPE_CORRECTION_H
+#define SLANG_IR_VARIABLE_SCOPE_CORRECTION_H
+
+namespace Slang
+{
+
+struct IRModule;
+
+/// This pass correct the scope of variables in loop regions
+///
+/// In the IR optimization pass, we turn all the loop to do-while loop form.
+/// But in the do-while loop form, the loop body block is dominating the
+/// blocks after the loop break block. E.g.
+///
+/// do {
+/// A
+/// } while (cond);
+/// B
+///
+/// In the above example, the block A is dominating block B. This assumption
+/// is fine for SPIRV and IR code, however, it's incorrect for all the other
+/// language targets (e.g. c/c++/cuda/glsl/hlsl) because the instructions defined
+/// in the block A are not visible from block B. Therefore, when translating to
+/// other textual language, there could be issue for the variables scope.
+///
+/// To fix this issue, we first detect the instructions that are defined
+/// inside the loop block (block A), then check if these instructions are used after
+/// the break block (block B). If so, we duplicate these instructions right before
+/// their users such that we can make those instructions available globally.
+void applyVariableScopeCorrection(IRModule* module, TargetRequest* targetReq);
+
+}
+
+#endif // SLANG_IR_VARIABLE_SCOPE_CORRECTION_H