diff options
| author | Yong He <yonghe@outlook.com> | 2023-02-16 13:55:32 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-16 13:55:32 -0800 |
| commit | 4c4826d47eeef4675daae4ae53ff76f4d5ebd84a (patch) | |
| tree | ed4af0ded878e4f06e9641ce61d26ffd7c89ccbc /source/slang/slang-ir.h | |
| parent | eda88e513e8b1e2abc05e9dc8555f237d96472df (diff) | |
Overhaul global inst deduplication and cpp/cuda backend. (#2654)
* Overhaul global inst deduplication and cpp/cuda backend.
* Update IR documentation.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir.h')
| -rw-r--r-- | source/slang/slang-ir.h | 233 |
1 files changed, 226 insertions, 7 deletions
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 41b140972..9b8aa5cb7 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -37,12 +37,14 @@ enum : IROpFlags kIROpFlags_None = 0, kIROpFlag_Parent = 1 << 0, ///< This op is a parent op kIROpFlag_UseOther = 1 << 1, ///< If set this op can use 'other bits' to store information + kIROpFlag_Hoistable = 1 << 2, ///< If set this op is a hoistable inst that needs to be deduplicated. + kIROpFlag_Global = 1 << 3, ///< If set this op should always be hoisted but should never be deduplicated. }; /* Bit usage of IROp is a follows MainOp | Other -Bit range: 0-7 | Remaining bits +Bit range: 0-10 | Remaining bits For doing range checks (for example for doing isa tests), the value is masked by kIROpMeta_OpMask, such that the Other bits don't interfere. The other bits can be used for storage for anything that needs to identify as a different 'op' or 'type'. It is currently @@ -92,6 +94,9 @@ struct IROpInfo // Flags to control how we emit additional info IROpFlags flags; + + bool isHoistable() const { return (flags & kIROpFlag_Hoistable) != 0; } + bool isGlobal() const { return (flags & kIROpFlag_Global) != 0; } }; // Look up the info for an op @@ -206,6 +211,43 @@ struct IRInstList : IRInstListBase }; template<typename T> +struct IRModifiableInstList +{ + IRInst* parent; + List<IRInst*> workList; + + IRModifiableInstList() {} + + IRModifiableInstList(T* parent, T* first, T* last); + + T* getFirst() { return workList.getCount() ? (T*)workList.getFirst() : nullptr; } + T* getLast() { return workList.getCount() ? (T*)workList.getLast() : nullptr; } + + struct Iterator + { + IRModifiableInstList<T>* list; + Index position = 0; + + Iterator() {} + Iterator(IRModifiableInstList<T>* inList, Index inPos) : list(inList), position(inPos) {} + + T* operator*() + { + return (T*)(list->workList[position]); + } + void operator++(); + + bool operator!=(Iterator const& i) + { + return i.list != list || i.position != position; + } + }; + + Iterator begin() { return Iterator(this, 0); } + Iterator end() { return Iterator(this, workList.getCount()); } +}; + +template<typename T> struct IRFilteredInstList : IRInstListBase { IRFilteredInstList() {} @@ -591,6 +633,14 @@ struct IRInst getLastChild()); } + IRModifiableInstList<IRInst> getModifiableChildren() + { + return IRModifiableInstList<IRInst>( + this, + getFirstChild(), + getLastChild()); + } + /// A doubly-linked list containing any decorations and then any children of this instruction. /// /// We store both the decorations and children of an instruction @@ -607,7 +657,13 @@ struct IRInst IRInst* getFirstDecorationOrChild() { return m_decorationsAndChildren.first; } IRInst* getLastDecorationOrChild() { return m_decorationsAndChildren.last; } IRInstListBase getDecorationsAndChildren() { return m_decorationsAndChildren; } - + IRModifiableInstList<IRInst> getModifiableDecorationsAndChildren() + { + return IRModifiableInstList<IRInst>( + this, + m_decorationsAndChildren.first, + m_decorationsAndChildren.last); + } void removeAndDeallocateAllDecorationsAndChildren(); #ifdef SLANG_ENABLE_IR_BREAK_ALLOC @@ -647,6 +703,12 @@ struct IRInst getOperands()[index].set(value); } + void unsafeSetOperand(UInt index, IRInst* value) + { + SLANG_ASSERT(getOperands()[index].user != nullptr); + getOperands()[index].init(this, value); + } + // @@ -773,6 +835,39 @@ typename IRInstList<T>::Iterator IRInstList<T>::end() } template<typename T> +IRModifiableInstList<T>::IRModifiableInstList(T* inParent, T* first, T* last) +{ + parent = inParent; + for (auto item = first; item; item = item->next) + { + workList.add(item); + if (item == last) + break; + } +} + +template<typename T> +void IRModifiableInstList<T>::Iterator::operator++() +{ + position++; + while (position < list->workList.getCount()) + { + auto inst = list->workList[position]; + if (!as<T>(inst)) + { + // Skip insts that are not of type T. + } + else if (list->parent != inst->parent) + { + // Skip insts that are no longer in its original parent. + } + else + break; + position++; + } +} + +template<typename T> IRFilteredInstList<T>::IRFilteredInstList(IRInst* fst, IRInst* lst) { first = fst; @@ -1796,6 +1891,104 @@ struct IRModuleInst : IRInst IR_LEAF_ISA(Module) }; +struct IRModule; + +// Description of an instruction to be used for global value numbering +struct IRInstKey +{ + IRInst* inst; + + HashCode getHashCode(); +}; + +bool operator==(IRInstKey const& left, IRInstKey const& right); + +struct IRConstantKey +{ + IRConstant* inst; + + bool operator==(const IRConstantKey& rhs) const { return inst->equal(rhs.inst); } + HashCode getHashCode() const { return inst->getHashCode(); } +}; + +struct SharedIRBuilder +{ +public: + SharedIRBuilder() + {} + + explicit SharedIRBuilder(IRModule* module) + { + init(module); + } + + void init(IRModule* module); + + IRModule* getModule() + { + return m_module; + } + + Session* getSession() + { + return m_session; + } + + void insertBlockAlongEdge(IREdge const& edge); + + // Rebuilds `globalValueNumberingMap`. This is necessary if any existing + // keys are modified (thus its hash code is changed). + void deduplicateAndRebuildGlobalNumberingMap(); + + // Replaces all uses of oldInst with newInst, and ensures the global numbering map is valid after the replacement. + void replaceGlobalInst(IRInst* oldInst, IRInst* newInst); + + void removeHoistableInstFromGlobalNumberingMap(IRInst* inst); + + void tryHoistInst(IRInst* inst); + + typedef Dictionary<IRInstKey, IRInst*> GlobalValueNumberingMap; + typedef Dictionary<IRConstantKey, IRConstant*> ConstantMap; + + GlobalValueNumberingMap& getGlobalValueNumberingMap() { return m_globalValueNumberingMap; } + Dictionary<IRInst*, IRInst*>& getInstReplacementMap() { return m_instReplacementMap; } + + void _addGlobalNumberingEntry(IRInst* inst) + { + m_globalValueNumberingMap.Add(IRInstKey{ inst }, inst); + m_instReplacementMap.Remove(inst); + tryHoistInst(inst); + } + void _removeGlobalNumberingEntry(IRInst* inst) + { + IRInst* value = nullptr; + if (m_globalValueNumberingMap.TryGetValue(IRInstKey{ inst }, value)) + { + if (value == inst) + { + m_globalValueNumberingMap.Remove(IRInstKey{ inst }); + } + } + } + + ConstantMap& getConstantMap() { return m_constantMap; } + +private: + // The module that will own all of the IR + IRModule* m_module; + + // The parent compilation session + Session* m_session; + + GlobalValueNumberingMap m_globalValueNumberingMap; + + // Duplicate insts that are still alive and needs to be replaced in m_globalValueNumberMap + // when used as an operand to create another inst. + Dictionary<IRInst*, IRInst*> m_instReplacementMap; + + ConstantMap m_constantMap; +}; + struct IRModule : RefObject { public: @@ -1810,6 +2003,8 @@ public: SLANG_FORCE_INLINE IRModuleInst* getModuleInst() const { return m_moduleInst; } SLANG_FORCE_INLINE MemoryArena& getMemoryArena() { return m_memoryArena; } + SharedIRBuilder* getSharedBuilder() const { return &m_sharedBuilder; } + IRInstListBase getGlobalInsts() const { return getModuleInst()->getChildren(); } /// Create an empty instruction with the `op` opcode and space for @@ -1853,6 +2048,7 @@ private: IRModule(Session* session) : m_session(session) , m_memoryArena(kMemoryArenaBlockSize) + , m_sharedBuilder(this) { } @@ -1870,6 +2066,9 @@ private: /// The memory arena from which all IR instructions (and any associated state) in this module are allocated. MemoryArena m_memoryArena; + + /// Shared contexts for constructing and maintaining the IR. + mutable SharedIRBuilder m_sharedBuilder; }; struct IRSpecializationDictionaryItem : public IRInst @@ -1943,13 +2142,17 @@ uint32_t& _debugGetIRAllocCounter(); // TODO: Ellie, comment and move somewhere more appropriate? template<typename I = IRInst, typename F> -static void traverseUses(IRInst* inst, F f) +static void traverseUsers(IRInst* inst, F f) { - auto n = inst->firstUse; - IRUse* u; - while((u = n) != nullptr) + List<IRUse*> uses; + for (auto use = inst->firstUse; use; use = use->nextUse) { - n = u->nextUse; + uses.add(use); + } + for (auto u : uses) + { + if (u->usedValue != inst) + continue; if(auto s = as<I>(u->getUser())) { f(s); @@ -1957,6 +2160,22 @@ static void traverseUses(IRInst* inst, F f) } } +template<typename F> +static void traverseUses(IRInst* inst, F f) +{ + List<IRUse*> uses; + for (auto use = inst->firstUse; use; use = use->nextUse) + { + uses.add(use); + } + for (auto u : uses) + { + if (u->usedValue != inst) + continue; + f(u); + } +} + namespace detail { // A helper to get the singular pointer argument of something callable |
