summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir.h
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-02-16 13:55:32 -0800
committerGitHub <noreply@github.com>2023-02-16 13:55:32 -0800
commit4c4826d47eeef4675daae4ae53ff76f4d5ebd84a (patch)
treeed4af0ded878e4f06e9641ce61d26ffd7c89ccbc /source/slang/slang-ir.h
parenteda88e513e8b1e2abc05e9dc8555f237d96472df (diff)
Overhaul global inst deduplication and cpp/cuda backend. (#2654)
* Overhaul global inst deduplication and cpp/cuda backend. * Update IR documentation. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir.h')
-rw-r--r--source/slang/slang-ir.h233
1 files changed, 226 insertions, 7 deletions
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 41b140972..9b8aa5cb7 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -37,12 +37,14 @@ enum : IROpFlags
kIROpFlags_None = 0,
kIROpFlag_Parent = 1 << 0, ///< This op is a parent op
kIROpFlag_UseOther = 1 << 1, ///< If set this op can use 'other bits' to store information
+ kIROpFlag_Hoistable = 1 << 2, ///< If set this op is a hoistable inst that needs to be deduplicated.
+ kIROpFlag_Global = 1 << 3, ///< If set this op should always be hoisted but should never be deduplicated.
};
/* Bit usage of IROp is a follows
MainOp | Other
-Bit range: 0-7 | Remaining bits
+Bit range: 0-10 | Remaining bits
For doing range checks (for example for doing isa tests), the value is masked by kIROpMeta_OpMask, such that the Other bits don't interfere.
The other bits can be used for storage for anything that needs to identify as a different 'op' or 'type'. It is currently
@@ -92,6 +94,9 @@ struct IROpInfo
// Flags to control how we emit additional info
IROpFlags flags;
+
+ bool isHoistable() const { return (flags & kIROpFlag_Hoistable) != 0; }
+ bool isGlobal() const { return (flags & kIROpFlag_Global) != 0; }
};
// Look up the info for an op
@@ -206,6 +211,43 @@ struct IRInstList : IRInstListBase
};
template<typename T>
+struct IRModifiableInstList
+{
+ IRInst* parent;
+ List<IRInst*> workList;
+
+ IRModifiableInstList() {}
+
+ IRModifiableInstList(T* parent, T* first, T* last);
+
+ T* getFirst() { return workList.getCount() ? (T*)workList.getFirst() : nullptr; }
+ T* getLast() { return workList.getCount() ? (T*)workList.getLast() : nullptr; }
+
+ struct Iterator
+ {
+ IRModifiableInstList<T>* list;
+ Index position = 0;
+
+ Iterator() {}
+ Iterator(IRModifiableInstList<T>* inList, Index inPos) : list(inList), position(inPos) {}
+
+ T* operator*()
+ {
+ return (T*)(list->workList[position]);
+ }
+ void operator++();
+
+ bool operator!=(Iterator const& i)
+ {
+ return i.list != list || i.position != position;
+ }
+ };
+
+ Iterator begin() { return Iterator(this, 0); }
+ Iterator end() { return Iterator(this, workList.getCount()); }
+};
+
+template<typename T>
struct IRFilteredInstList : IRInstListBase
{
IRFilteredInstList() {}
@@ -591,6 +633,14 @@ struct IRInst
getLastChild());
}
+ IRModifiableInstList<IRInst> getModifiableChildren()
+ {
+ return IRModifiableInstList<IRInst>(
+ this,
+ getFirstChild(),
+ getLastChild());
+ }
+
/// A doubly-linked list containing any decorations and then any children of this instruction.
///
/// We store both the decorations and children of an instruction
@@ -607,7 +657,13 @@ struct IRInst
IRInst* getFirstDecorationOrChild() { return m_decorationsAndChildren.first; }
IRInst* getLastDecorationOrChild() { return m_decorationsAndChildren.last; }
IRInstListBase getDecorationsAndChildren() { return m_decorationsAndChildren; }
-
+ IRModifiableInstList<IRInst> getModifiableDecorationsAndChildren()
+ {
+ return IRModifiableInstList<IRInst>(
+ this,
+ m_decorationsAndChildren.first,
+ m_decorationsAndChildren.last);
+ }
void removeAndDeallocateAllDecorationsAndChildren();
#ifdef SLANG_ENABLE_IR_BREAK_ALLOC
@@ -647,6 +703,12 @@ struct IRInst
getOperands()[index].set(value);
}
+ void unsafeSetOperand(UInt index, IRInst* value)
+ {
+ SLANG_ASSERT(getOperands()[index].user != nullptr);
+ getOperands()[index].init(this, value);
+ }
+
//
@@ -773,6 +835,39 @@ typename IRInstList<T>::Iterator IRInstList<T>::end()
}
template<typename T>
+IRModifiableInstList<T>::IRModifiableInstList(T* inParent, T* first, T* last)
+{
+ parent = inParent;
+ for (auto item = first; item; item = item->next)
+ {
+ workList.add(item);
+ if (item == last)
+ break;
+ }
+}
+
+template<typename T>
+void IRModifiableInstList<T>::Iterator::operator++()
+{
+ position++;
+ while (position < list->workList.getCount())
+ {
+ auto inst = list->workList[position];
+ if (!as<T>(inst))
+ {
+ // Skip insts that are not of type T.
+ }
+ else if (list->parent != inst->parent)
+ {
+ // Skip insts that are no longer in its original parent.
+ }
+ else
+ break;
+ position++;
+ }
+}
+
+template<typename T>
IRFilteredInstList<T>::IRFilteredInstList(IRInst* fst, IRInst* lst)
{
first = fst;
@@ -1796,6 +1891,104 @@ struct IRModuleInst : IRInst
IR_LEAF_ISA(Module)
};
+struct IRModule;
+
+// Description of an instruction to be used for global value numbering
+struct IRInstKey
+{
+ IRInst* inst;
+
+ HashCode getHashCode();
+};
+
+bool operator==(IRInstKey const& left, IRInstKey const& right);
+
+struct IRConstantKey
+{
+ IRConstant* inst;
+
+ bool operator==(const IRConstantKey& rhs) const { return inst->equal(rhs.inst); }
+ HashCode getHashCode() const { return inst->getHashCode(); }
+};
+
+struct SharedIRBuilder
+{
+public:
+ SharedIRBuilder()
+ {}
+
+ explicit SharedIRBuilder(IRModule* module)
+ {
+ init(module);
+ }
+
+ void init(IRModule* module);
+
+ IRModule* getModule()
+ {
+ return m_module;
+ }
+
+ Session* getSession()
+ {
+ return m_session;
+ }
+
+ void insertBlockAlongEdge(IREdge const& edge);
+
+ // Rebuilds `globalValueNumberingMap`. This is necessary if any existing
+ // keys are modified (thus its hash code is changed).
+ void deduplicateAndRebuildGlobalNumberingMap();
+
+ // Replaces all uses of oldInst with newInst, and ensures the global numbering map is valid after the replacement.
+ void replaceGlobalInst(IRInst* oldInst, IRInst* newInst);
+
+ void removeHoistableInstFromGlobalNumberingMap(IRInst* inst);
+
+ void tryHoistInst(IRInst* inst);
+
+ typedef Dictionary<IRInstKey, IRInst*> GlobalValueNumberingMap;
+ typedef Dictionary<IRConstantKey, IRConstant*> ConstantMap;
+
+ GlobalValueNumberingMap& getGlobalValueNumberingMap() { return m_globalValueNumberingMap; }
+ Dictionary<IRInst*, IRInst*>& getInstReplacementMap() { return m_instReplacementMap; }
+
+ void _addGlobalNumberingEntry(IRInst* inst)
+ {
+ m_globalValueNumberingMap.Add(IRInstKey{ inst }, inst);
+ m_instReplacementMap.Remove(inst);
+ tryHoistInst(inst);
+ }
+ void _removeGlobalNumberingEntry(IRInst* inst)
+ {
+ IRInst* value = nullptr;
+ if (m_globalValueNumberingMap.TryGetValue(IRInstKey{ inst }, value))
+ {
+ if (value == inst)
+ {
+ m_globalValueNumberingMap.Remove(IRInstKey{ inst });
+ }
+ }
+ }
+
+ ConstantMap& getConstantMap() { return m_constantMap; }
+
+private:
+ // The module that will own all of the IR
+ IRModule* m_module;
+
+ // The parent compilation session
+ Session* m_session;
+
+ GlobalValueNumberingMap m_globalValueNumberingMap;
+
+ // Duplicate insts that are still alive and needs to be replaced in m_globalValueNumberMap
+ // when used as an operand to create another inst.
+ Dictionary<IRInst*, IRInst*> m_instReplacementMap;
+
+ ConstantMap m_constantMap;
+};
+
struct IRModule : RefObject
{
public:
@@ -1810,6 +2003,8 @@ public:
SLANG_FORCE_INLINE IRModuleInst* getModuleInst() const { return m_moduleInst; }
SLANG_FORCE_INLINE MemoryArena& getMemoryArena() { return m_memoryArena; }
+ SharedIRBuilder* getSharedBuilder() const { return &m_sharedBuilder; }
+
IRInstListBase getGlobalInsts() const { return getModuleInst()->getChildren(); }
/// Create an empty instruction with the `op` opcode and space for
@@ -1853,6 +2048,7 @@ private:
IRModule(Session* session)
: m_session(session)
, m_memoryArena(kMemoryArenaBlockSize)
+ , m_sharedBuilder(this)
{
}
@@ -1870,6 +2066,9 @@ private:
/// The memory arena from which all IR instructions (and any associated state) in this module are allocated.
MemoryArena m_memoryArena;
+
+ /// Shared contexts for constructing and maintaining the IR.
+ mutable SharedIRBuilder m_sharedBuilder;
};
struct IRSpecializationDictionaryItem : public IRInst
@@ -1943,13 +2142,17 @@ uint32_t& _debugGetIRAllocCounter();
// TODO: Ellie, comment and move somewhere more appropriate?
template<typename I = IRInst, typename F>
-static void traverseUses(IRInst* inst, F f)
+static void traverseUsers(IRInst* inst, F f)
{
- auto n = inst->firstUse;
- IRUse* u;
- while((u = n) != nullptr)
+ List<IRUse*> uses;
+ for (auto use = inst->firstUse; use; use = use->nextUse)
{
- n = u->nextUse;
+ uses.add(use);
+ }
+ for (auto u : uses)
+ {
+ if (u->usedValue != inst)
+ continue;
if(auto s = as<I>(u->getUser()))
{
f(s);
@@ -1957,6 +2160,22 @@ static void traverseUses(IRInst* inst, F f)
}
}
+template<typename F>
+static void traverseUses(IRInst* inst, F f)
+{
+ List<IRUse*> uses;
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ uses.add(use);
+ }
+ for (auto u : uses)
+ {
+ if (u->usedValue != inst)
+ continue;
+ f(u);
+ }
+}
+
namespace detail
{
// A helper to get the singular pointer argument of something callable