summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-spirv-legalize.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-08-23 06:58:50 -0700
committerGitHub <noreply@github.com>2023-08-23 21:58:50 +0800
commitc515bf9edf0ceefa9a0c9b36626ea7c8f72ce36f (patch)
tree670a3a80f0f60b7be7fd50e40d9d088f5e7607a7 /source/slang/slang-ir-spirv-legalize.cpp
parent6437c38e0a3c2c1daf36cb5e543dc0b467fa4b15 (diff)
Misc. SPIRV Fixes. (#3146)
* Lower all ByteAddressBuffer uses for SPIRV. * Misc. SPIRV Fixes. --------- Co-authored-by: Yong He <yhe@nvidia.com> Co-authored-by: Ellie Hermaszewska <ellieh@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-spirv-legalize.cpp')
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp166
1 files changed, 166 insertions, 0 deletions
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index f6294e2ba..36fdbd56a 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -11,6 +11,7 @@
#include "slang-ir-lower-buffer-element-type.h"
#include "slang-ir-layout.h"
#include "slang-ir-util.h"
+#include "slang-ir-dominators.h"
namespace Slang
{
@@ -245,7 +246,81 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
inst->removeAndDeallocate();
addUsersToWorkList(newCall);
}
+ return;
+ }
+
+ // According to SPIRV spec, the if the operands of a call has pointer
+ // type, then it can only be a memory-object. This means that if the
+ // pointer is a result of `getElementPtr`, we cannot use it as an
+ // argument. In this case, we have to allocate a temp var to pass the
+ // value, and write them back to the original pointer after the call.
+ //
+ // > SPIRV Spec section 2.16.1:
+ // > - Any pointer operand to an OpFunctionCall must be a memory object
+ // > declaration, or
+ // > - a pointer to an element in an array that is a memory object
+ // > declaration, where the element type is OpTypeSampler or OpTypeImage.
+ //
+ List<IRInst*> newArgs;
+ struct WriteBackPair { IRInst* originalAddrArg; IRInst* tempVar; };
+ List<WriteBackPair> writeBacks;
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ for (UInt i = 0; i < inst->getArgCount(); i++)
+ {
+ auto arg = inst->getArg(i);
+ auto ptrType = as<IRPtrTypeBase>(arg->getDataType());
+ if (!as<IRPtrTypeBase>(arg->getDataType()))
+ {
+ newArgs.add(arg);
+ continue;
+ }
+ // Is the arg already a memory-object by SPIRV definition?
+ // If so we don't need to allocate a temp var.
+ switch (arg->getOp())
+ {
+ case kIROp_Var:
+ case kIROp_GlobalVar:
+ newArgs.add(arg);
+ continue;
+ case kIROp_Param:
+ if (arg->getParent() == getParentFunc(arg)->getFirstBlock())
+ {
+ newArgs.add(arg);
+ continue;
+ }
+ break;
+ default:
+ break;
+ }
+ auto root = getRootAddr(arg);
+ if (root)
+ {
+ switch (root->getOp())
+ {
+ case kIROp_RWStructuredBufferGetElementPtr:
+ newArgs.add(arg);
+ continue;
+ }
+ }
+
+ // If we reach here, we need to allocate a temp var.
+ auto tempVar = builder.emitVar(ptrType->getValueType());
+ auto load = builder.emitLoad(arg);
+ builder.emitStore(tempVar, load);
+ newArgs.add(tempVar);
+ writeBacks.add(WriteBackPair{ arg, tempVar });
+ }
+ SLANG_ASSERT((UInt)newArgs.getCount() == inst->getArgCount());
+ auto newCall = builder.emitCallInst(inst->getFullType(), inst->getCallee(), newArgs);
+ for (auto wb : writeBacks)
+ {
+ auto newVal = builder.emitLoad(wb.tempVar);
+ builder.emitStore(wb.originalAddrArg, newVal);
}
+ inst->replaceUsesWith(newCall);
+ inst->removeAndDeallocate();
+ addUsersToWorkList(newCall);
}
Dictionary<IRInst*, IRInst*> m_mapArrayValueToVar;
@@ -430,6 +505,33 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
addUsersToWorkList(ptrType);
}
+ void duplicateMergeBlockIfNeeded(IRUse* breakBlockUse)
+ {
+ auto breakBlock = as<IRBlock>(breakBlockUse->get());
+ if (breakBlock->getFirstInst()->getOp() != kIROp_Unreachable)
+ {
+ return;
+ }
+ bool hasMoreThanOneUser = false;
+ for (auto use = breakBlock->firstUse; use; use = use->nextUse)
+ {
+ if (use->getUser() != breakBlockUse->getUser())
+ {
+ hasMoreThanOneUser = true;
+ break;
+ }
+ }
+ if (!hasMoreThanOneUser)
+ return;
+
+ // Create a duplicate block for this use.
+ IRBuilder builder(breakBlock);
+ builder.setInsertBefore(breakBlock);
+ auto block = builder.emitBlock();
+ builder.emitUnreachable();
+ breakBlockUse->set(block);
+ }
+
void processLoop(IRLoop* loop)
{
@@ -545,6 +647,56 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
builder.emitBranch(t, ps.getCount(), ps.getBuffer());
}
}
+ duplicateMergeBlockIfNeeded(&loop->breakBlock);
+ }
+
+ void processIfElse(IRIfElse* inst)
+ {
+ duplicateMergeBlockIfNeeded(&inst->afterBlock);
+
+ // SPIRV does not allow using merge block directly as true/false block,
+ // so we need to create an intermediate block if this is the case.
+ IRBuilder builder(inst);
+ if (inst->getTrueBlock() == inst->getAfterBlock())
+ {
+ builder.setInsertBefore(inst->getAfterBlock());
+ auto newBlock = builder.emitBlock();
+ builder.emitBranch(inst->getAfterBlock());
+ inst->trueBlock.set(newBlock);
+ }
+ if (inst->getFalseBlock() == inst->getAfterBlock())
+ {
+ builder.setInsertBefore(inst->getAfterBlock());
+ auto newBlock = builder.emitBlock();
+ builder.emitBranch(inst->getAfterBlock());
+ inst->falseBlock.set(newBlock);
+ }
+ }
+
+ void processSwitch(IRSwitch* inst)
+ {
+ duplicateMergeBlockIfNeeded(&inst->breakLabel);
+
+ // SPIRV does not allow using merge block directly as case block,
+ // so we need to create an intermediate block if this is the case.
+ IRBuilder builder(inst);
+ if (inst->getDefaultLabel() == inst->getBreakLabel())
+ {
+ builder.setInsertBefore(inst->getBreakLabel());
+ auto newBlock = builder.emitBlock();
+ builder.emitBranch(inst->getBreakLabel());
+ inst->defaultLabel.set(newBlock);
+ }
+ for (UInt i = 0; i < inst->getCaseCount(); i++)
+ {
+ if (inst->getCaseLabel(i) == inst->getBreakLabel())
+ {
+ builder.setInsertBefore(inst->getBreakLabel());
+ auto newBlock = builder.emitBlock();
+ builder.emitBranch(inst->getBreakLabel());
+ inst->getCaseLabelUse(i)->set(newBlock);
+ }
+ }
}
void processModule()
@@ -593,6 +745,12 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
case kIROp_loop:
processLoop(as<IRLoop>(inst));
break;
+ case kIROp_ifElse:
+ processIfElse(as<IRIfElse>(inst));
+ break;
+ case kIROp_Switch:
+ processSwitch(as<IRSwitch>(inst));
+ break;
default:
for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
{
@@ -601,6 +759,14 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
break;
}
}
+
+ // SPIRV requires a dominator block to appear before dominated blocks.
+ // After legalizing the control flow, we need to sort our blocks to ensure this is true.
+ for (auto globalInst : m_module->getGlobalInsts())
+ {
+ if (auto func = as<IRGlobalValueWithCode>(globalInst))
+ sortBlocksInFunc(func);
+ }
}
};