diff options
| author | Yong He <yonghe@outlook.com> | 2021-08-12 13:14:15 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2021-08-12 13:14:15 -0700 |
| commit | 6406523511037987d8b8ab881aea41389afd57eb (patch) | |
| tree | 79f24b6cba377340c2f4d3dcf9fed78fc586f3e0 | |
| parent | 389d21d982da34815b65b10cae63088c397eecc8 (diff) | |
Further implementation of SPIRV direct emit. (#1920)
* Further implementation of SPIRV direct emit.
This change implements:
- Struct, Vector, Matrix and Unsized Array types.
- Basic arithmetic opcodes, vector construct, swizzle etc.
- getElementPtr, getElement, fieldAddress, extractField.
- SPIRV target intrinsics with SPIRV asm code in stdlib.
- RWStructuredBuffer and StructuredBuffer.
- Pointer storage class propagation.
- Control flow.
* Fix.
29 files changed, 1989 insertions, 106 deletions
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj index f175d6a31..0ccd5ed70 100644 --- a/build/visual-studio/slang/slang.vcxproj +++ b/build/visual-studio/slang/slang.vcxproj @@ -209,6 +209,7 @@ <ClInclude Include="..\..\..\source\slang\slang-diagnostics.h" /> <ClInclude Include="..\..\..\source\slang\slang-doc-extractor.h" /> <ClInclude Include="..\..\..\source\slang\slang-doc-markdown-writer.h" /> + <ClInclude Include="..\..\..\source\slang\slang-emit-base.h" /> <ClInclude Include="..\..\..\source\slang\slang-emit-c-like.h" /> <ClInclude Include="..\..\..\source\slang\slang-emit-cpp.h" /> <ClInclude Include="..\..\..\source\slang\slang-emit-cuda.h" /> @@ -263,6 +264,8 @@ <ClInclude Include="..\..\..\source\slang\slang-ir-specialize-function-call.h" /> <ClInclude Include="..\..\..\source\slang\slang-ir-specialize-resources.h" /> <ClInclude Include="..\..\..\source\slang\slang-ir-specialize.h" /> + <ClInclude Include="..\..\..\source\slang\slang-ir-spirv-legalize.h" /> + <ClInclude Include="..\..\..\source\slang\slang-ir-spirv-snippet.h" /> <ClInclude Include="..\..\..\source\slang\slang-ir-ssa.h" /> <ClInclude Include="..\..\..\source\slang\slang-ir-string-hash.h" /> <ClInclude Include="..\..\..\source\slang\slang-ir-strip-witness-tables.h" /> @@ -335,6 +338,7 @@ <ClCompile Include="..\..\..\source\slang\slang-diagnostics.cpp" /> <ClCompile Include="..\..\..\source\slang\slang-doc-extractor.cpp" /> <ClCompile Include="..\..\..\source\slang\slang-doc-markdown-writer.cpp" /> + <ClCompile Include="..\..\..\source\slang\slang-emit-base.cpp" /> <ClCompile Include="..\..\..\source\slang\slang-emit-c-like.cpp" /> <ClCompile Include="..\..\..\source\slang\slang-emit-cpp.cpp" /> <ClCompile Include="..\..\..\source\slang\slang-emit-cuda.cpp" /> @@ -389,6 +393,8 @@ <ClCompile Include="..\..\..\source\slang\slang-ir-specialize-function-call.cpp" /> <ClCompile Include="..\..\..\source\slang\slang-ir-specialize-resources.cpp" /> <ClCompile Include="..\..\..\source\slang\slang-ir-specialize.cpp" /> + <ClCompile Include="..\..\..\source\slang\slang-ir-spirv-legalize.cpp" /> + <ClCompile Include="..\..\..\source\slang\slang-ir-spirv-snippet.cpp" /> <ClCompile Include="..\..\..\source\slang\slang-ir-ssa.cpp" /> <ClCompile Include="..\..\..\source\slang\slang-ir-string-hash.cpp" /> <ClCompile Include="..\..\..\source\slang\slang-ir-strip-witness-tables.cpp" /> diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters index 1697a385c..a7affb00a 100644 --- a/build/visual-studio/slang/slang.vcxproj.filters +++ b/build/visual-studio/slang/slang.vcxproj.filters @@ -78,6 +78,9 @@ <ClInclude Include="..\..\..\source\slang\slang-doc-markdown-writer.h"> <Filter>Header Files</Filter> </ClInclude> + <ClInclude Include="..\..\..\source\slang\slang-emit-base.h"> + <Filter>Header Files</Filter> + </ClInclude> <ClInclude Include="..\..\..\source\slang\slang-emit-c-like.h"> <Filter>Header Files</Filter> </ClInclude> @@ -240,6 +243,12 @@ <ClInclude Include="..\..\..\source\slang\slang-ir-specialize.h"> <Filter>Header Files</Filter> </ClInclude> + <ClInclude Include="..\..\..\source\slang\slang-ir-spirv-legalize.h"> + <Filter>Header Files</Filter> + </ClInclude> + <ClInclude Include="..\..\..\source\slang\slang-ir-spirv-snippet.h"> + <Filter>Header Files</Filter> + </ClInclude> <ClInclude Include="..\..\..\source\slang\slang-ir-ssa.h"> <Filter>Header Files</Filter> </ClInclude> @@ -452,6 +461,9 @@ <ClCompile Include="..\..\..\source\slang\slang-doc-markdown-writer.cpp"> <Filter>Source Files</Filter> </ClCompile> + <ClCompile Include="..\..\..\source\slang\slang-emit-base.cpp"> + <Filter>Source Files</Filter> + </ClCompile> <ClCompile Include="..\..\..\source\slang\slang-emit-c-like.cpp"> <Filter>Source Files</Filter> </ClCompile> @@ -614,6 +626,12 @@ <ClCompile Include="..\..\..\source\slang\slang-ir-specialize.cpp"> <Filter>Source Files</Filter> </ClCompile> + <ClCompile Include="..\..\..\source\slang\slang-ir-spirv-legalize.cpp"> + <Filter>Source Files</Filter> + </ClCompile> + <ClCompile Include="..\..\..\source\slang\slang-ir-spirv-snippet.cpp"> + <Filter>Source Files</Filter> + </ClCompile> <ClCompile Include="..\..\..\source\slang\slang-ir-ssa.cpp"> <Filter>Source Files</Filter> </ClCompile> diff --git a/source/core/slang-token-reader.h b/source/core/slang-token-reader.h index 0d59eea76..26539732c 100644 --- a/source/core/slang-token-reader.h +++ b/source/core/slang-token-reader.h @@ -73,7 +73,7 @@ namespace Misc { TokenType Type = TokenType::Unknown; String Content; CodePosition Position; - TokenFlags flags; + TokenFlags flags = 0; Token() = default; Token(TokenType type, const String & content, int line, int col, int pos, String fileName, TokenFlags flags = 0) : flags(flags) diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index bb573c2b2..dd4f95cf5 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -199,13 +199,15 @@ struct StructuredBuffer out uint numStructs, out uint stride); - __target_intrinsic(glsl, "$0._data[$1]") + __target_intrinsic(glsl, "$0._data[$1]") + __target_intrinsic(spirv_direct, "%addr = 65 resultType*StorageBuffer resultId _0 _1; 61 resultType resultId %addr;") T Load(int location); T Load(int location, out uint status); __subscript(uint index) -> T { __target_intrinsic(glsl, "$0._data[$1]") + __target_intrinsic(spirv_direct, "%addr = 65 resultType*StorageBuffer resultId _0 _1; 61 resultType resultId %addr;") get; }; }; @@ -629,12 +631,14 @@ struct $(item.name) uint IncrementCounter(); __target_intrinsic(glsl, "$0._data[$1]") + __target_intrinsic(spirv_direct, "%addr = 65 resultType*StorageBuffer resultId _0 _1; 61 resultType resultId %addr;") T Load(int location); T Load(int location, out uint status); __subscript(uint index) -> T { __target_intrinsic(glsl, "$0._data[$1]") + __target_intrinsic(spirv_direct, "*StorageBuffer 65 resultType resultId _0 _1") ref; } }; diff --git a/source/slang/slang-capability-defs.h b/source/slang/slang-capability-defs.h index f66add15b..fc60f4dfa 100644 --- a/source/slang/slang-capability-defs.h +++ b/source/slang/slang-capability-defs.h @@ -55,6 +55,7 @@ SLANG_CAPABILITY_ATOM0(GLSL, glsl, Concrete,TargetFormat,0) SLANG_CAPABILITY_ATOM0(C, c, Concrete,TargetFormat,0) SLANG_CAPABILITY_ATOM0(CPP, cpp, Concrete,TargetFormat,0) SLANG_CAPABILITY_ATOM0(CUDA, cuda, Concrete,TargetFormat,0) +SLANG_CAPABILITY_ATOM0(SPIRV_DIRECT, spirv_direct, Concrete, TargetFormat, 0) // We have multiple capabilities for the various SPIR-V versions, // arranged so that they inherit from one another to represent which versions diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index d909a190c..028886c7e 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -1445,6 +1445,7 @@ namespace Slang if (target == CodeGenTarget::SPIRV && compileRequest->shouldEmitSPIRVDirectly) { List<uint8_t> spirv; + targetReq->setDirectSPIRVEmitMode(); SLANG_RETURN_ON_FAIL(emitSPIRVForEntryPointsDirectly(compileRequest, entryPointIndices, targetReq, spirv)); auto spirvBlob = ListBlob::moveCreate(spirv); downstreamResult = new BlobDownstreamCompileResult(DownstreamDiagnostics(), spirvBlob); diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 01f23918b..b829fd0ee 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -1190,6 +1190,8 @@ namespace Slang return (targetFlags & SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM) != 0; } + void setDirectSPIRVEmitMode(); + Linkage* getLinkage() { return linkage; } CodeGenTarget getTarget() { return format; } Profile getTargetProfile() { return targetProfile; } @@ -1217,6 +1219,7 @@ namespace Slang List<CapabilityAtom> rawCapabilities; CapabilitySet cookedCapabilities; LineDirectiveMode lineDirectiveMode = LineDirectiveMode::Default; + bool m_emitSPIRVDirectly = false; }; /// Are we generating code for a D3D API? diff --git a/source/slang/slang-emit-base.cpp b/source/slang/slang-emit-base.cpp new file mode 100644 index 000000000..d00b723ab --- /dev/null +++ b/source/slang/slang-emit-base.cpp @@ -0,0 +1,55 @@ +#include "slang-emit-base.h" + +namespace Slang +{ + +IRInst* SourceEmitterBase::getSpecializedValue(IRSpecialize* specInst) +{ + auto base = specInst->getBase(); + + // It is possible to have a `specialize(...)` where the first + // operand is also a `specialize(...)`, so that we need to + // look at what declaration is being specialized at the inner + // step to find the one being specialized at the outer step. + // + while (auto baseSpecialize = as<IRSpecialize>(base)) + { + base = getSpecializedValue(baseSpecialize); + } + + auto baseGeneric = as<IRGeneric>(base); + if (!baseGeneric) + return base; + + auto lastBlock = baseGeneric->getLastBlock(); + if (!lastBlock) + return base; + + auto returnInst = as<IRReturnVal>(lastBlock->getTerminator()); + if (!returnInst) + return base; + + return returnInst->getVal(); +} + +void SourceEmitterBase::handleRequiredCapabilities(IRInst* inst) +{ + auto decoratedValue = inst; + while (auto specInst = as<IRSpecialize>(decoratedValue)) + { + decoratedValue = getSpecializedValue(specInst); + } + + handleRequiredCapabilitiesImpl(decoratedValue); +} + +IRVarLayout* SourceEmitterBase::getVarLayout(IRInst* var) +{ + auto decoration = var->findDecoration<IRLayoutDecoration>(); + if (!decoration) + return nullptr; + + return as<IRVarLayout>(decoration->getLayout()); +} + +} diff --git a/source/slang/slang-emit-base.h b/source/slang/slang-emit-base.h new file mode 100644 index 000000000..ffbf56618 --- /dev/null +++ b/source/slang/slang-emit-base.h @@ -0,0 +1,29 @@ +// slang-emit-base.h +#ifndef SLANG_EMIT_BASE_H +#define SLANG_EMIT_BASE_H + +#include "../core/slang-basic.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-ir-restructure.h" + +namespace Slang +{ + +class SourceEmitterBase : public RefObject +{ +public: + IRInst* getSpecializedValue(IRSpecialize* specInst); + + /// Inspect the capabilities required by `inst` (according to its decorations), + /// and ensure that those capabilities have been detected and stored in the + /// target-specific extension tracker. + void handleRequiredCapabilities(IRInst* inst); + virtual void handleRequiredCapabilitiesImpl(IRInst* inst) { SLANG_UNUSED(inst); } + + static IRVarLayout* getVarLayout(IRInst* var); +}; + +} +#endif diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 84c369d40..f9d71beb9 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -1450,17 +1450,6 @@ void CLikeSourceEmitter::_emitCallArgList(IRCall* inst) m_writer->emit(")"); } -void CLikeSourceEmitter::handleRequiredCapabilities(IRInst* inst) -{ - auto decoratedValue = inst; - while (auto specInst = as<IRSpecialize>(decoratedValue)) - { - decoratedValue = getSpecializedValue(specInst); - } - - handleRequiredCapabilitiesImpl(decoratedValue); -} - void CLikeSourceEmitter::emitCallExpr(IRCall* inst, EmitOpInfo outerPrec) { auto funcValue = inst->getOperand(0); @@ -2164,15 +2153,6 @@ void CLikeSourceEmitter::emitSemantics(IRInst* inst) emitSemanticsImpl(inst); } -IRVarLayout* CLikeSourceEmitter::getVarLayout(IRInst* var) -{ - auto decoration = var->findDecoration<IRLayoutDecoration>(); - if (!decoration) - return nullptr; - - return as<IRVarLayout>(decoration->getLayout()); -} - void CLikeSourceEmitter::emitLayoutSemantics(IRInst* inst, char const* uniformSemanticSpelling) { emitLayoutSemanticsImpl(inst, uniformSemanticSpelling); @@ -2781,35 +2761,6 @@ void CLikeSourceEmitter::emitParamTypeImpl(IRType* type, String const& name) emitType(type, name); } -IRInst* CLikeSourceEmitter::getSpecializedValue(IRSpecialize* specInst) -{ - auto base = specInst->getBase(); - - // It is possible to have a `specialize(...)` where the first - // operand is also a `specialize(...)`, so that we need to - // look at what declaration is being specialized at the inner - // step to find the one being specialized at the outer step. - // - while(auto baseSpecialize = as<IRSpecialize>(base)) - { - base = getSpecializedValue(baseSpecialize); - } - - auto baseGeneric = as<IRGeneric>(base); - if (!baseGeneric) - return base; - - auto lastBlock = baseGeneric->getLastBlock(); - if (!lastBlock) - return base; - - auto returnInst = as<IRReturnVal>(lastBlock->getTerminator()); - if (!returnInst) - return base; - - return returnInst->getVal(); -} - void CLikeSourceEmitter::emitFuncDecl(IRFunc* func) { // We don't want to emit declarations for operations diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index 90db7476c..f699ed255 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -6,6 +6,7 @@ #include "slang-compiler.h" +#include "slang-emit-base.h" #include "slang-emit-precedence.h" #include "slang-emit-source-writer.h" @@ -16,7 +17,7 @@ namespace Slang { -class CLikeSourceEmitter: public RefObject +class CLikeSourceEmitter: public SourceEmitterBase { public: struct Desc @@ -292,8 +293,6 @@ public: void emitSemantics(IRInst* inst); void emitSemanticsUsingVarLayout(IRVarLayout* varLayout); - static IRVarLayout* getVarLayout(IRInst* var); - void emitLayoutSemantics(IRInst* inst, char const* uniformSemanticSpelling = "register"); // When we are about to traverse an edge from one block to another, @@ -323,8 +322,6 @@ public: void emitParamType(IRType* type, String const& name) { emitParamTypeImpl(type, name); } - IRInst* getSpecializedValue(IRSpecialize* specInst); - void emitFuncDecl(IRFunc* func); IREntryPointLayout* getEntryPointLayout(IRFunc* func); @@ -453,15 +450,8 @@ public: virtual bool tryEmitGlobalParamImpl(IRGlobalParam* varDecl, IRType* varType) { SLANG_UNUSED(varDecl); SLANG_UNUSED(varType); return false; } virtual bool tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) { SLANG_UNUSED(inst); SLANG_UNUSED(inOuterPrec); return false; } - /// Inspect the capabilities required by `inst` (according to its decorations), - /// and ensure that those capabilities have been detected and stored in the - /// target-specific extension tracker. - void handleRequiredCapabilities(IRInst* inst); - virtual void handleRequiredCapabilitiesImpl(IRInst* inst) { SLANG_UNUSED(inst); } - virtual void emitPostKeywordTypeAttributesImpl(IRInst* inst) { SLANG_UNUSED(inst); } - void _emitArrayType(IRArrayType* arrayType, DeclaratorInfo* declarator); void _emitUnsizedArrayType(IRUnsizedArrayType* arrayType, DeclaratorInfo* declarator); void _emitType(IRType* type, DeclaratorInfo* declarator); diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index fe039feb0..37fd673ed 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -2,11 +2,14 @@ #include "slang-emit.h" #include "slang-compiler.h" +#include "slang-emit-base.h" + #include "slang-ir.h" #include "slang-ir-insts.h" - +#include "slang-ir-layout.h" +#include "slang-ir-spirv-snippet.h" +#include "slang-ir-spirv-legalize.h" #include "spirv/unified1/spirv.h" - #include "../core/slang-memory-arena.h" namespace Slang @@ -36,16 +39,7 @@ namespace Slang // [2.3: Physical Layout of a SPIR-V Module and Instruction] // // > A SPIR-V module is a single linear stream of words. -// -// [2.2: Terms] -// -// > Word: 32 bits. -// -// Despite the importance to SPIR-V, the `spirv.h` header doesn't -// define a type for words, so we'll do it here. - /// A SPIR-V word. -typedef uint32_t SpvWord; // [2.3: Physical Layout of a SPIR-V Module and Instruction] // @@ -268,6 +262,14 @@ void SpvInstParent::dumpTo(List<SpvWord>& ioWords) } } +/// The context for inlining a SPV assembly snippet. +struct SpvSnippetEmitContext +{ + SpvInst* resultType; + Dictionary<SpvStorageClass, IRInst*> qualifiedResultTypes; + List<SpvWord> argumentIds; +}; + // Now that we've defined the intermediate data structures we will // use to represent SPIR-V code during emission, we will move on // to defining the main context type that will drive SPIR-V @@ -275,10 +277,14 @@ void SpvInstParent::dumpTo(List<SpvWord>& ioWords) /// Context used for translating a Slang IR module to SPIR-V struct SPIRVEmitContext + : public SourceEmitterBase + , public SPIRVEmitSharedContext { /// The Slang IR module being translated IRModule* m_irModule; + DiagnosticSink* m_sink; + // [2.2: Terms] // // > <id>: A numerical name; the name used to refer to an object, a type, @@ -385,12 +391,36 @@ struct SPIRVEmitContext /// Map a Slang IR instruction to the corresponding SPIR-V instruction Dictionary<IRInst*, SpvInst*> m_mapIRInstToSpvInst; + // Sometimes we need to reserve an ID for an `IRInst` without actually + // emitting it. We use `m_mapIRInstToSpvID` to hold all reserved SpvIDs. + // Use `getIRInstSpvID` to obtain an SpvID for an `IRInst` if the + // `IRInst` may not have been emitted. + Dictionary<IRInst*, SpvWord> m_mapIRInstToSpvID; + /// Register that `irInst` maps to `spvInst` void registerInst(IRInst* irInst, SpvInst* spvInst) { m_mapIRInstToSpvInst.Add(irInst, spvInst); } + /// Get or reserve a SpvID for an IR value. + SpvWord getIRInstSpvID(IRInst* inst) + { + // If we have already emitted an SpvInst for `inst`, return its ID. + SpvInst* spvInst = nullptr; + if (m_mapIRInstToSpvInst.TryGetValue(inst, spvInst)) + return getID(spvInst); + // Check if we have reserved an ID for `inst`. + SpvWord result = 0; + if (m_mapIRInstToSpvID.TryGetValue(inst, result)) + return result; + // Otherwise, reserve a new ID for inst, and register it in `m_mapIRInstToSpvID`. + result = m_nextID; + ++m_nextID; + m_mapIRInstToSpvID[inst] = result; + return result; + } + // When we are emitting an instruction that can produce // a result, we will allocate an <id> to it so that other // instructions can refer to it. @@ -409,6 +439,18 @@ struct SPIRVEmitContext return id; } + struct VectorTypeKey + { + BaseType baseType; + IRIntegerValue elementCount; + HashCode getHashCode() { return combineHash((int)baseType, (HashCode)elementCount); } + bool operator==(const VectorTypeKey& other) + { + return baseType == other.baseType && elementCount == other.elementCount; + } + }; + Dictionary<VectorTypeKey, SpvInst*> m_vectorTypes; + // We will build up `SpvInst`s in a stateful fashion, // mostly for convenience. We could in theory compute // the number of words each instruction needs, then allocate @@ -467,6 +509,8 @@ struct SPIRVEmitContext if(irInst) { registerInst(irInst, spvInst); + // If we have reserved an SpvID for `irInst`, make sure to use it. + m_mapIRInstToSpvID.TryGetValue(irInst, spvInst->id); } // Set up the scope @@ -561,9 +605,6 @@ struct SPIRVEmitContext /// Emit an operand to the current instruction, which references `src` by its <id> void emitOperand(IRInst* src) { - // We first ensure that the `src` instruction has been emitted, - // and then handle it as for any other <id> operand. - // SpvInst* spvSrc = ensureInst(src); emitOperand(getID(spvSrc)); } @@ -629,6 +670,25 @@ struct SPIRVEmitContext emitOperand(getID(m_currentInst)); } + void emitOperand(SpvDecoration decoration) { emitOperand((SpvWord)decoration); } + + void emitOperand(SpvBuiltIn builtin) { emitOperand((SpvWord)builtin); } + void emitOperand(SpvStorageClass val) { emitOperand((SpvWord)val); } + + Dictionary<IRIntegerValue, SpvInst*> m_spvIntConstants; + SpvInst* emitConstant(IRIntegerValue val, IRType* type) + { + SpvInst* result = nullptr; + if (m_spvIntConstants.TryGetValue(val, result)) + return result; + return emitInst( + getSection(SpvLogicalSectionID::Constants), + nullptr, + SpvOpConstant, + type, + kResultID, + (SpvWord)val); + } // As another convenience, there are often cases where // we will want to emit all of the operands of some // IR instruction as <id> operands of a SPIR-V @@ -742,6 +802,16 @@ struct SPIRVEmitContext return spvInst; } + template<typename OperandEmitFunc> + SpvInst* emitInstCustomOperandFunc(SpvInstParent* parent, IRInst* irInst, SpvOp opcode, const OperandEmitFunc& f) + { + InstConstructScope scopeInst(this, opcode, irInst); + SpvInst* spvInst = scopeInst; + f(); + parent->addInst(spvInst); + return spvInst; + } + // Now that we've gotten the core infrastructure out of the way, // let's start looking at emitting some instructions that make // up a SPIR-V module. @@ -826,14 +896,110 @@ struct SPIRVEmitContext CASE(kIROp_DoubleType, 64); #undef CASE - - // > OpTypeVector - // > OpTypeMatrix + case kIROp_PtrType: + case kIROp_RefType: + case kIROp_OutType: + case kIROp_InOutType: + { + SpvStorageClass storageClass = SpvStorageClassFunction; + auto ptrType = as<IRPtrTypeBase>(inst); + if (ptrType->hasAddressSpace()) + storageClass = (SpvStorageClass)ptrType->getAddressSpace(); + return emitInst( + getSection(SpvLogicalSectionID::Types), + inst, + SpvOpTypePointer, + kResultID, + storageClass, + inst->getOperand(0)); + } + case kIROp_StructType: + { + return emitInstCustomOperandFunc( + getSection(SpvLogicalSectionID::Types), inst, SpvOpTypeStruct, [&]() { + emitOperand(kResultID); + for (auto field : static_cast<IRStructType*>(inst)->getFields()) + { + emitOperand(field->getFieldType()); + // TODO: decorate offset + } + }); + } + case kIROp_VectorType: + { + auto vectorType = static_cast<IRVectorType*>(inst); + return ensureVectorType( + static_cast<IRBasicType*>(vectorType->getElementType())->getBaseType(), + static_cast<IRIntLit*>(vectorType->getElementCount())->getValue(), + vectorType); + } + case kIROp_MatrixType: + { + auto matrixType = static_cast<IRMatrixType*>(inst); + auto vectorSpvType = ensureVectorType( + static_cast<IRBasicType*>(matrixType->getElementType())->getBaseType(), + static_cast<IRIntLit*>(matrixType->getRowCount())->getValue(), + nullptr); + auto matrixSPVType = emitInst( + getSection(SpvLogicalSectionID::Types), + inst, + SpvOpTypeMatrix, + kResultID, + vectorSpvType, + (SpvWord)static_cast<IRIntLit*>(matrixType->getColumnCount())->getValue()); + // TODO: properly compute matrix stride. + auto columnCount = static_cast<IRIntLit*>(matrixType->getRowCount())->getValue(); + uint32_t stride = 0; + switch (columnCount) + { + case 1: + stride = 4; + break; + case 2: + stride = 8; + break; + case 3: + case 4: + stride = 16; + break; + default: + break; + } + emitInst( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + SpvOpDecorate, + matrixSPVType, + SpvDecorationRowMajor, + SpvDecorationMatrixStride, + stride); + return matrixSPVType; + } + case kIROp_UnsizedArrayType: + { + auto elementType = static_cast<IRUnsizedArrayType*>(inst)->getElementType(); + auto runtimeArrayType = emitInst( + getSection(SpvLogicalSectionID::Types), + nullptr, + SpvOpTypeRuntimeArray, + kResultID, + elementType); + // TODO: properly decorate stride. + IRSizeAndAlignment sizeAndAlignment; + getNaturalSizeAndAlignment(this->m_targetRequest, elementType, &sizeAndAlignment); + emitInst( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + SpvOpDecorate, + runtimeArrayType, + SpvDecorationArrayStride, + (SpvWord)sizeAndAlignment.getStride()); + return runtimeArrayType; + } // > OpTypeImage // > OpTypeSampler // > OpTypeArray // > OpTypeRuntimeArray - // > OpTypeStruct // > OpTypeOpaque // > OpTypePointer @@ -858,6 +1024,15 @@ struct SPIRVEmitContext // return emitFunc(as<IRFunc>(inst)); + case kIROp_BoolLit: + case kIROp_IntLit: + case kIROp_FloatLit: + return emitLit(inst); + + case kIROp_GlobalParam: + return emitGlobalParam(as<IRGlobalParam>(inst)); + case kIROp_GlobalVar: + return emitGlobalVar(as<IRGlobalVar>(inst)); // ... default: @@ -866,6 +1041,162 @@ struct SPIRVEmitContext } } + // Ensures an SpvInst for the specified vector type is emitted. + // `inst` represents an optional `IRVectorType` inst representing the vector type, if + // it is nullptr, this function will create one. + SpvInst* ensureVectorType(BaseType baseType, IRIntegerValue elementCount, IRVectorType* inst) + { + VectorTypeKey key = {baseType, elementCount}; + SpvInst* result = nullptr; + if (m_vectorTypes.TryGetValue(key, result)) + return result; + if (!inst) + { + IRBuilder builder; + builder.sharedBuilder = &m_sharedIRBuilder; + builder.setInsertInto(m_irModule->getModuleInst()); + inst = builder.getVectorType( + builder.getBasicType(baseType), + builder.getIntValue(builder.getIntType(), elementCount)); + } + result = emitInst( + getSection(SpvLogicalSectionID::Types), + inst, + SpvOpTypeVector, + kResultID, + inst->getElementType(), + (SpvWord)elementCount); + m_vectorTypes[key] = result; + return result; + } + + void emitVarLayout(SpvInst* varInst, IRVarLayout* layout) + { + for (auto rr : layout->getOffsetAttrs()) + { + UInt index = rr->getOffset(); + UInt space = rr->getSpace(); + switch (rr->getResourceKind()) + { + case LayoutResourceKind::Uniform: + break; + + case LayoutResourceKind::VaryingInput: + emitInst( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + SpvOpDecorate, + varInst, + SpvDecorationLocation, + (SpvWord)index); + emitInst( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + SpvOpDecorate, + varInst, + SpvDecorationIndex, + (SpvWord)space); + break; + case LayoutResourceKind::VaryingOutput: + emitInst( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + SpvOpDecorate, + varInst, + SpvDecorationLocation, + (SpvWord)index); + if (space) + { + emitInst( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + SpvOpDecorate, + varInst, + SpvDecorationIndex, + (SpvWord)space); + } + break; + + case LayoutResourceKind::SpecializationConstant: + emitInst( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + SpvOpDecorate, + varInst, + SpvDecorationSpecId, + (SpvWord)index); + break; + + case LayoutResourceKind::ConstantBuffer: + case LayoutResourceKind::ShaderResource: + case LayoutResourceKind::UnorderedAccess: + case LayoutResourceKind::SamplerState: + case LayoutResourceKind::DescriptorTableSlot: + emitInst( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + SpvOpDecorate, + varInst, + SpvDecorationBinding, + (SpvWord)index); + if (space) + { + emitInst( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + SpvOpDecorate, + varInst, + SpvDecorationDescriptorSet, + (SpvWord)space); + } + break; + default: + break; + } + } + } + /// Emit a global parameter definition. + SpvInst* emitGlobalParam(IRGlobalParam* param) + { + auto layout = getVarLayout(param); + auto storageClass = SpvStorageClassUniform; + if (auto ptrType = as<IRPtrTypeBase>(param->getDataType())) + { + if (ptrType->hasAddressSpace()) + storageClass = (SpvStorageClass)ptrType->getAddressSpace(); + } + auto varInst = emitInst( + getSection(SpvLogicalSectionID::GlobalVariables), + param, + SpvOpVariable, + param->getDataType(), + kResultID, + storageClass); + emitVarLayout(varInst, layout); + return varInst; + } + + /// Emit a global variable definition. + SpvInst* emitGlobalVar(IRGlobalVar* globalVar) + { + auto layout = getVarLayout(globalVar); + auto storageClass = SpvStorageClassUniform; + if (auto ptrType = as<IRPtrTypeBase>(globalVar->getDataType())) + { + if (ptrType->hasAddressSpace()) + storageClass = (SpvStorageClass)ptrType->getAddressSpace(); + } + auto varInst = emitInst( + getSection(SpvLogicalSectionID::GlobalVariables), + globalVar, + SpvOpVariable, + globalVar->getDataType(), + kResultID, + storageClass); + emitVarLayout(varInst, layout); + return varInst; + } + /// Emit the given `irFunc` to SPIR-V SpvInst* emitFunc(IRFunc* irFunc) { @@ -951,9 +1282,7 @@ struct SPIRVEmitContext // for( auto irParam : irFunc->getParams() ) { - emitInst(spvFunc, irParam, SpvOpFunctionParameter, - irParam->getFullType(), - kResultID); + emitParam(spvFunc, irParam); } // [3.32.17. Control-Flow Instructions] @@ -992,11 +1321,13 @@ struct SPIRVEmitContext // [3.32.17. Control-Flow Instructions] // // > OpPhi - // - // TODO: We eventually need to emit `OpPhi` instructions corresponding - // to the parameters of any non-entry block, with operands representing - // the values passed along incoming edges from the predecessor blocks. - + if (irBlock != irFunc->getFirstBlock()) + { + for (auto irParam : irBlock->getParams()) + { + emitPhi(spvBlock, irParam); + } + } for( auto irInst : irBlock->getOrdinaryInsts() ) { // Any instructions local to the block will be emitted as children @@ -1036,16 +1367,243 @@ struct SPIRVEmitContext /// Emit an instruction that is local to the body of the given `parent`. SpvInst* emitLocalInst(SpvInstParent* parent, IRInst* inst) { + auto getBlockID = [=](IRBlock* block) + { + SpvInst* spvInst = nullptr; + m_mapIRInstToSpvInst.TryGetValue(block, spvInst); + SLANG_ASSERT(spvInst); + return getID(spvInst); + }; switch( inst->getOp() ) { default: SLANG_UNIMPLEMENTED_X("unhandled instruction opcode"); break; + case kIROp_Specialize: + return nullptr; + case kIROp_Var: + return emitVar(parent, inst); + case kIROp_Call: + return emitCall(parent, inst); + case kIROp_FieldAddress: + return emitFieldAddress(parent, as<IRFieldAddress>(inst)); + case kIROp_FieldExtract: + return emitFieldExtract(parent, as<IRFieldExtract>(inst)); + case kIROp_getElementPtr: + return emitGetElementPtr(parent, as<IRGetElementPtr>(inst)); + case kIROp_getElement: + return emitGetElement(parent, as<IRGetElement>(inst)); + case kIROp_Load: + return emitLoad(parent, as<IRLoad>(inst)); + case kIROp_Store: + return emitStore(parent, as<IRStore>(inst)); + case kIROp_swizzle: + return emitSwizzle(parent, as<IRSwizzle>(inst)); + case kIROp_Construct: + return emitConstruct(parent, inst); + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_Div: + case kIROp_IRem: + case kIROp_FRem: + case kIROp_Neg: + case kIROp_Not: + case kIROp_And: + case kIROp_Or: + case kIROp_BitNot: + case kIROp_BitAnd: + case kIROp_BitOr: + case kIROp_BitXor: + case kIROp_Less: + case kIROp_Leq: + case kIROp_Eql: + case kIROp_Neq: + case kIROp_Greater: + case kIROp_Geq: + case kIROp_Rsh: + case kIROp_Lsh: + return emitArithmetic(parent, inst); + case kIROp_ReturnVal: + return emitInst( + parent, inst, SpvOpReturnValue, as<IRReturnVal>(inst)->getVal()); + case kIROp_ReturnVoid: + return emitInst(parent, inst, SpvOpReturn); + case kIROp_discard: + return emitInst(parent, inst, SpvOpKill); + case kIROp_unconditionalBranch: + return emitInst( + parent, + inst, + SpvOpBranch, + getBlockID(as<IRUnconditionalBranch>(inst)->getTargetBlock())); + case kIROp_loop: + { + auto loopInst = as<IRLoop>(inst); + + SpvWord loopControl = 0; + if (auto loopControlDecoration = + loopInst->findDecoration<IRLoopControlDecoration>()) + { + switch (loopControlDecoration->getMode()) + { + case IRLoopControl::kIRLoopControl_Unroll: + loopControl = 0x1; + break; + case IRLoopControl::kIRLoopControl_Loop: + loopControl = 0x2; + break; + default: + break; + } + } + emitInst( + parent, + nullptr, + SpvOpLoopMerge, + getBlockID(loopInst->getBreakBlock()), + getBlockID(loopInst->getContinueBlock()), + loopControl); + + return emitInst(parent, inst, SpvOpBranch, loopInst->getTargetBlock()); + } + case kIROp_ifElse: + { + auto ifelseInst = as<IRIfElse>(inst); + auto afterBlockID = getBlockID(ifelseInst->getAfterBlock()); + emitInst( + parent, + nullptr, + SpvOpSelectionMerge, + afterBlockID); + auto falseLabel = ifelseInst->getFalseBlock(); + return emitInst( + parent, + inst, + SpvOpBranchConditional, + ifelseInst->getCondition(), + ifelseInst->getTrueBlock(), + falseLabel ? getID(ensureInst(falseLabel)) : afterBlockID); + } + case kIROp_Switch: + { + auto switchInst = as<IRSwitch>(inst); + auto mergeBlockID = getBlockID(switchInst->getBreakLabel()); + emitInst( + parent, + nullptr, + SpvOpSelectionMerge, mergeBlockID); + return emitInstCustomOperandFunc(parent, inst, SpvOpSwitch, [&]() { + emitOperand(switchInst->getCondition()); + auto defaultLabel = switchInst->getDefaultLabel(); + emitOperand(defaultLabel ? getID(ensureInst(defaultLabel)) : mergeBlockID); + for (UInt c = 0; c < switchInst->getCaseCount(); c++) + { + auto value = switchInst->getCaseValue(c); + auto intLit = as<IRIntLit>(value); + SLANG_ASSERT(intLit); + emitOperand((SpvWord)intLit->getValue()); + auto caseLabel = switchInst->getCaseLabel(c); + emitOperand(caseLabel ? getID(ensureInst(caseLabel)) : mergeBlockID); + } + }); + } + case kIROp_Unreachable: + return emitInst(parent, inst, SpvOpUnreachable); + case kIROp_conditionalBranch: + SLANG_UNEXPECTED("Unstructured branching is not supported by SPIRV."); + } + } - // [3.32.17. Control-Flow Instructions] - // - // > OpReturn - case kIROp_ReturnVoid: return emitInst(parent, inst, SpvOpReturn); + SpvInst* emitLit(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_IntLit: + { + auto value = as<IRIntLit>(inst)->getValue(); + switch (as<IRBasicType>(inst->getDataType())->getBaseType()) + { + case BaseType::Int64: + case BaseType::UInt64: + return emitInst( + getSection(SpvLogicalSectionID::Constants), + inst, + SpvOpConstant, + inst->getDataType(), + kResultID, + (SpvWord)(value & 0xFFFFFFFF), + (SpvWord)((value >> 32) & 0xFFFFFFFF)); + default: + return emitInst( + getSection(SpvLogicalSectionID::Constants), + inst, + SpvOpConstant, + inst->getDataType(), + kResultID, + (SpvWord)value); + } + } + case kIROp_FloatLit: + { + auto value = as<IRConstant>(inst)->value.floatVal; + switch (as<IRBasicType>(inst->getDataType())->getBaseType()) + { + case BaseType::Half: + return emitInst( + getSection(SpvLogicalSectionID::Constants), + inst, + SpvOpConstant, + inst->getDataType(), + kResultID, + (SpvWord)(FloatToHalf((float)value))); + case BaseType::Float: + return emitInst( + getSection(SpvLogicalSectionID::Constants), + inst, + SpvOpConstant, + inst->getDataType(), + kResultID, + (SpvWord)(FloatAsInt((float)value))); + case BaseType::Double: + { + auto ival = DoubleAsInt64(value); + return emitInst( + getSection(SpvLogicalSectionID::Constants), + inst, + SpvOpConstant, + inst->getDataType(), + kResultID, + (SpvWord)(ival&0xFFFFFFFF), + (SpvWord)(ival>>32)); + } + default: + return nullptr; + } + } + case kIROp_BoolLit: + { + if (as<IRBoolLit>(inst)->getValue()) + { + return emitInst( + getSection(SpvLogicalSectionID::Constants), + inst, + SpvOpConstantTrue, + inst->getDataType(), + kResultID); + } + else + { + return emitInst( + getSection(SpvLogicalSectionID::Constants), + inst, + SpvOpConstantFalse, + inst->getDataType(), + kResultID); + } + } + default: + return nullptr; } } @@ -1184,24 +1742,655 @@ struct SPIRVEmitContext } } - SPIRVEmitContext(IRModule* module) : - m_irModule(module), - m_memoryArena(2048) + SpvInst* emitBuiltinSystemVal(SpvInstParent* parent, IRInst* inst, SpvBuiltIn builtinVal) + { + IRBuilder builder; + builder.sharedBuilder = &m_sharedIRBuilder; + builder.setInsertBefore(inst); + + auto ptrIRType = builder.getPtrType(inst->getDataType()); + auto varInst = emitInst(parent, inst, SpvOpVariable, ptrIRType, kResultID); + emitInst( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + SpvOpDecorate, + varInst, + SpvDecorationBuiltIn, + builtinVal); + return varInst; + } + + SpvInst* emitParam(SpvInstParent* parent, IRInst* inst) + { + if (auto layout = getVarLayout(inst)) + { + if (auto systemValueAttr = layout->findAttr<IRSystemValueSemanticAttr>()) + { + String semanticName = systemValueAttr->getName(); + semanticName = semanticName.toLower(); + if (semanticName == "sv_dispatchthreadid") + { + return emitBuiltinSystemVal(parent, inst, SpvBuiltInGlobalInvocationId); + } + } + } + return emitInst(parent, inst, SpvOpFunctionParameter, inst->getFullType(), kResultID); + } + + SpvInst* emitVar(SpvInstParent* parent, IRInst* inst) + { + SpvWord storageClass = SpvStorageClassFunction; + auto rate = inst->getFullType()->getRate(); + if (rate) + { + switch (rate->getOp()) + { + case kIROp_GroupSharedRate: + storageClass = SpvStorageClassWorkgroup; + break; + default: + break; + } + } + return emitInst(parent, inst, SpvOpVariable, inst->getFullType(), kResultID, storageClass); + } + + /// Cached `IRParam` indices in an `IRBlock`. For use in `getParamIndexInBlock`. + struct BlockParamIndexInfo : public RefObject + { + Dictionary<IRParam*, int> mapParamToIndex; + }; + Dictionary<IRBlock*, RefPtr<BlockParamIndexInfo>> m_mapIRBlockToParamIndexInfo; + + /// Returns the index of an `IRParam` inside a `IRBlock`. + /// The results are cached in `m_mapIRBlockToParamIndexInfo` to avoid linear search. + int getParamIndexInBlock(IRBlock* block, IRParam* paramInst) + { + RefPtr<BlockParamIndexInfo> info; + int result = -1; + if (m_mapIRBlockToParamIndexInfo.TryGetValue(block, info)) + { + info->mapParamToIndex.TryGetValue(paramInst, result); + SLANG_ASSERT(result != -1); + return result; + } + info = new BlockParamIndexInfo(); + int paramIndex = 0; + for (auto param : block->getParams()) + { + info->mapParamToIndex[param] = paramIndex; + if (param == paramInst) + result = paramIndex; + paramIndex++; + } + m_mapIRBlockToParamIndexInfo[block] = info; + SLANG_ASSERT(result != -1); + return result; + } + + SpvInst* emitPhi(SpvInstParent* parent, IRParam* inst) + { + // An `IRParam` in an ordinary `IRBlock` represents a phi value. + // We can translate them directly to SPIRV's `Phi` instruction. + // In order to do that, we need to figure out the source values + // of this `IRParam`, which can be done by looking at the users + // of current `IRBlock`. + + // First, we find the index of this param. + IRBlock* block = as<IRBlock>(inst->getParent()); + SLANG_ASSERT(block); + int paramIndex = getParamIndexInBlock(block, inst); + + // Emit a Phi instruction. + return emitInstCustomOperandFunc(parent, inst, SpvOpPhi, [&]() { + emitOperand(inst->getFullType()); + emitOperand(kResultID); + // Find phi arguments from incoming branch instructions that target `block`. + for (auto use = block->firstUse; use; use = use->nextUse) + { + auto branchInst = use->getUser(); + UInt argStartIndex = 0; + switch (branchInst->getOp()) + { + case kIROp_unconditionalBranch: + argStartIndex = 1; + break; + case kIROp_loop: + argStartIndex = 3; + break; + default: + // A phi argument can only come from an unconditional branch inst. + // Other uses are not relavent so we should skip. + continue; + } + SLANG_ASSERT(argStartIndex + paramIndex < branchInst->getOperandCount()); + auto valueInst = branchInst->getOperand(argStartIndex + paramIndex); + emitOperand(valueInst); + auto sourceBlock = as<IRBlock>(branchInst->getParent()); + SLANG_ASSERT(sourceBlock); + emitOperand(getIRInstSpvID(sourceBlock)); + } + }); + } + + SpvInst* emitCall(SpvInstParent* parent, IRInst* inst) + { + auto funcValue = inst->getOperand(0); + + // Does this function declare any requirements. + handleRequiredCapabilities(funcValue); + + // We want to detect any call to an intrinsic operation, and inline + // the SPIRV snippet directly at the call site. + if (auto targetIntrinsic = Slang::findBestTargetIntrinsicDecoration( + funcValue, m_targetRequest->getTargetCaps())) + { + return emitIntrinsicCallExpr(parent, static_cast<IRCall*>(inst), targetIntrinsic); + } + else + { + return emitInst( + parent, inst, SpvOpFunctionCall, inst->getFullType(), kResultID, OperandsOf(inst)); + } + } + + SpvInst* emitIntrinsicCallExpr( + SpvInstParent* parent, + IRCall* inst, + IRTargetIntrinsicDecoration* intrinsic) + { + SpvSnippet* snippet = getParsedSpvSnippet(intrinsic); + SpvSnippetEmitContext context; + context.resultType = ensureInst(inst->getFullType()); + for (SlangUInt i = 0; i < inst->getArgCount(); i++) + { + auto argInst = ensureInst(inst->getArg(i)); + if (argInst) + { + context.argumentIds.add(getID(argInst)); + } + else + { + context.argumentIds.add(0xFFFFFFFF); + } + } + // A SPIRV snippet may refer to the result type of this inst with a + // different storage-class qualifier. We need to pre-create these + // storage-class-qualified result pointer types so they can be used + // during inlining of the snippet. + if (auto oldPtrType = as<IRPtrTypeBase>(inst->getDataType())) + { + for (auto storageClass : snippet->usedResultTypeStorageClasses) + { + IRBuilder builder; + builder.sharedBuilder = &m_sharedIRBuilder; + builder.setInsertBefore(inst); + auto newPtrType = builder.getPtrType( + oldPtrType->getOp(), oldPtrType->getValueType(), storageClass); + context.qualifiedResultTypes[storageClass] = newPtrType; + } + } + return emitSpvSnippet(parent, inst, context, snippet); + } + + SpvInst* emitSpvSnippet( + SpvInstParent* parent, + IRCall* inst, + const SpvSnippetEmitContext& context, + SpvSnippet* snippet) + { + ShortList<SpvInst*> emittedInsts; + for (Index i = 0; i < snippet->instructions.getCount(); i++) + { + auto& spvSnippetInst = snippet->instructions[i]; + InstConstructScope scopeInst(this, (SpvOp)spvSnippetInst.opCode, nullptr); + SpvInst* spvInst = scopeInst; + for (auto operand : spvSnippetInst.operands) + { + switch (operand.type) + { + case SpvSnippet::ASMOperandType::SpvWord: + emitOperand((SpvWord)operand.content); + break; + case SpvSnippet::ASMOperandType::ObjectReference: + SLANG_ASSERT( + operand.content >= 0 && operand.content < context.argumentIds.getCount()); + emitOperand(context.argumentIds[operand.content]); + break; + case SpvSnippet::ASMOperandType::ResultId: + emitOperand(kResultID); + break; + case SpvSnippet::ASMOperandType::ResultTypeId: + if (operand.content != -1) + { + emitOperand(context.qualifiedResultTypes[(SpvStorageClass)operand.content] + .GetValue()); + } + else + { + emitOperand(context.resultType); + } + break; + case SpvSnippet::ASMOperandType::InstReference: + SLANG_ASSERT(operand.content >= 0 && operand.content < emittedInsts.getCount()); + emitOperand(getID(emittedInsts[operand.content])); + break; + } + } + parent->addInst(spvInst); + emittedInsts.add(spvInst); + } + auto resultInst = emittedInsts.getLast(); + registerInst(inst, resultInst); + return resultInst; + } + + struct StructTypeInfo : public RefObject + { + Dictionary<IRStructKey*, Index> structFieldIndices; + }; + + Dictionary<IRStructType*, RefPtr<StructTypeInfo>> m_structTypeInfos; + + RefPtr<StructTypeInfo> createStructTypeInfo(IRStructType* structType) + { + RefPtr<StructTypeInfo> typeInfo = new StructTypeInfo(); + Index index = 0; + for (auto field : structType->getFields()) + { + typeInfo->structFieldIndices[field->getKey()] = index; + index++; + } + return typeInfo; + } + Index getStructFieldId(IRStructType* structType, IRStructKey* structFieldKey) + { + RefPtr<StructTypeInfo> info; + if (!m_structTypeInfos.TryGetValue(structType, info)) + { + info = createStructTypeInfo(structType); + m_structTypeInfos[structType] = info; + } + Index fieldIndex = -1; + info->structFieldIndices.TryGetValue(structFieldKey, fieldIndex); + SLANG_ASSERT(fieldIndex != -1); + return fieldIndex; + } + + SpvInst* emitFieldAddress(SpvInstParent* parent, IRFieldAddress* fieldAddress) + { + IRBuilder builder; + builder.sharedBuilder = &m_sharedIRBuilder; + builder.setInsertBefore(fieldAddress); + + auto base = fieldAddress->getBase(); + SpvWord baseId = 0; + IRStructType* baseStructType = nullptr; + + if (auto ptrLikeType = as<IRPointerLikeType>(base->getDataType())) + { + baseStructType = as<IRStructType>(ptrLikeType->getElementType()); + baseId = getID(ensureInst(base)); + } + else if (auto ptrType = as<IRPtrTypeBase>(base->getDataType())) + { + baseStructType = as<IRStructType>(ptrType->getValueType()); + baseId = getID(ensureInst(base)); + } + else + { + baseStructType = as<IRStructType>(base->getDataType()); + + auto structPtrType = builder.getPtrType(baseStructType); + auto varInst = emitInst( + parent, nullptr, SpvOpVariable, structPtrType, kResultID, SpvStorageClassFunction); + emitInst(parent, nullptr, SpvOpStore, varInst, base); + baseId = getID(varInst); + } + SLANG_ASSERT(baseStructType && "field_address require base to be a struct."); + auto fieldId = emitConstant( + getStructFieldId(baseStructType, as<IRStructKey>(fieldAddress->getField())), + builder.getIntType()); + return emitInst( + parent, + fieldAddress, + SpvOpAccessChain, + fieldAddress->getFullType(), + kResultID, + baseId, + fieldId); + } + + SpvInst* emitFieldExtract(SpvInstParent* parent, IRFieldExtract* inst) + { + IRBuilder builder; + builder.sharedBuilder = &m_sharedIRBuilder; + builder.setInsertBefore(inst); + + IRStructType* baseStructType = as<IRStructType>(inst->getBase()->getDataType()); + SLANG_ASSERT(baseStructType && "field_extract require base to be a struct."); + auto fieldId = emitConstant( + getStructFieldId(baseStructType, as<IRStructKey>(inst->getField())), + builder.getIntType()); + + return emitInst( + parent, + inst, + SpvOpCompositeExtract, + inst->getDataType(), + kResultID, + inst->getBase(), + fieldId); + } + + SpvInst* emitGetElementPtr(SpvInstParent* parent, IRGetElementPtr* inst) + { + auto base = inst->getBase(); + SpvWord baseId = 0; + IRArrayType* baseArrayType = nullptr; + + if (auto ptrLikeType = as<IRPointerLikeType>(base->getDataType())) + { + baseArrayType = as<IRArrayType>(ptrLikeType->getElementType()); + baseId = getID(ensureInst(base)); + } + else if (auto ptrType = as<IRPtrTypeBase>(base->getDataType())) + { + baseArrayType = as<IRArrayType>(ptrType->getValueType()); + baseId = getID(ensureInst(base)); + } + else + { + SLANG_ASSERT(!"invalid IR: base of getElementPtr must be a pointer."); + } + SLANG_ASSERT(baseArrayType && "getElementPtr require base to be an array."); + return emitInst( + parent, + inst, + SpvOpAccessChain, + inst->getFullType(), + kResultID, + baseId, + inst->getIndex()); + } + + SpvInst* emitGetElement(SpvInstParent* parent, IRGetElement* inst) + { + auto base = inst->getBase(); + SpvWord baseId = 0; + IRArrayType* baseArrayType = nullptr; + + if (auto ptrLikeType = as<IRPointerLikeType>(base->getDataType())) + { + baseArrayType = as<IRArrayType>(ptrLikeType->getElementType()); + baseId = getID(ensureInst(base)); + } + else if (auto ptrType = as<IRPtrTypeBase>(base->getDataType())) + { + baseArrayType = as<IRArrayType>(ptrType->getValueType()); + baseId = getID(ensureInst(base)); + } + else + { + SLANG_ASSERT(!"invalid IR: base of getElement must be a pointer."); + } + SLANG_ASSERT(baseArrayType && "getElement require base to be an array."); + + IRBuilder builder; + builder.sharedBuilder = &m_sharedIRBuilder; + builder.setInsertBefore(inst); + + auto ptr = emitInst( + parent, + nullptr, + SpvOpAccessChain, + builder.getPtrType(inst->getFullType()), + kResultID, + baseId, + inst->getIndex()); + return emitInst(parent, inst, SpvOpLoad, inst->getFullType(), kResultID, ptr); + } + + SpvInst* emitLoad(SpvInstParent* parent, IRLoad* inst) + { + return emitInst(parent, inst, SpvOpLoad, inst->getDataType(), kResultID, inst->getPtr()); + } + + SpvInst* emitStore(SpvInstParent* parent, IRStore* inst) + { + return emitInst(parent, inst, SpvOpStore, inst->getPtr(), inst->getVal()); + } + + SpvInst* emitSwizzle(SpvInstParent* parent, IRSwizzle* inst) + { + return emitInstCustomOperandFunc(parent, inst, SpvOpVectorShuffle, [&]() { + emitOperand(inst->getDataType()); + emitOperand(kResultID); + emitOperand(inst->getBase()); + emitOperand(inst->getBase()); + for (UInt i = 0; i < inst->getElementCount(); i++) + { + auto index = as<IRIntLit>(inst->getElementIndex(i)); + emitOperand((SpvWord)index->getValue()); + } + }); + } + + SpvInst* emitConstruct(SpvInstParent* parent, IRInst* inst) + { + if (as<IRBasicType>(inst->getDataType())) + { + if (inst->getOperandCount() == 1) + { + if (inst->getDataType() == inst->getOperand(0)->getDataType()) + return emitInst(parent, inst, SpvOpCopyObject, kResultID, inst->getOperand(0)); + else + return emitInst(parent, inst, SpvOpBitcast, inst->getDataType(), kResultID, inst->getOperand(0)); + } + else + { + SLANG_ASSERT(!"spirv emit: unsupported Construct inst."); + return nullptr; + } + } + else + { + return emitInst( + parent, + inst, + SpvOpCompositeConstruct, + inst->getDataType(), + kResultID, + OperandsOf(inst)); + } + } + + bool isSignedType(IRBasicType* basicType) + { + switch (basicType->getBaseType()) + { + case BaseType::Float: + case BaseType::Double: + return true; + case BaseType::Int: + case BaseType::Int16: + case BaseType::Int64: + case BaseType::Int8: + return true; + default: + return false; + } + } + + SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst) + { + IRType* elementType = inst->getDataType(); + if (auto vectorType = as<IRVectorType>(inst->getDataType())) + { + elementType = vectorType->getElementType(); + } + else if (auto matrixType = as<IRMatrixType>(inst->getDataType())) + { + //TODO: implement. + SLANG_ASSERT(!"unimplemented: matrix arithemetic"); + } + IRBasicType* basicType = as<IRBasicType>(elementType); + bool isFloatingPoint = false; + bool isBool = false; + switch (basicType->getBaseType()) + { + case BaseType::Float: + case BaseType::Double: + isFloatingPoint = true; + break; + case BaseType::Bool: + isBool = true; + default: + break; + } + SpvOp opCode = SpvOpUndef; + bool isSigned = isSignedType(basicType); + switch (inst->getOp()) + { + case kIROp_Add: + opCode = isFloatingPoint ? SpvOpFAdd : SpvOpIAdd; + break; + case kIROp_Sub: + opCode = isFloatingPoint ? SpvOpFSub : SpvOpISub; + break; + case kIROp_Mul: + opCode = isFloatingPoint ? SpvOpFMul : SpvOpIMul; + break; + case kIROp_Div: + opCode = isFloatingPoint ? SpvOpFDiv : isSigned ? SpvOpSDiv : SpvOpUDiv; + break; + case kIROp_IRem: + opCode = isSigned ? SpvOpSRem : SpvOpUMod; + break; + case kIROp_FRem: + opCode = SpvOpFRem; + break; + case kIROp_Less: + opCode = isFloatingPoint ? SpvOpFOrdLessThan + : isSigned ? SpvOpSLessThan : SpvOpULessThan; + break; + case kIROp_Leq: + opCode = isFloatingPoint ? SpvOpFOrdLessThanEqual + : isSigned ? SpvOpSLessThanEqual : SpvOpULessThanEqual; + break; + case kIROp_Eql: + opCode = isFloatingPoint ? SpvOpFOrdEqual : isBool ? SpvOpLogicalEqual : SpvOpIEqual; + break; + case kIROp_Neq: + opCode = isFloatingPoint ? SpvOpFOrdNotEqual + : isBool ? SpvOpLogicalNotEqual : SpvOpINotEqual; + break; + case kIROp_Geq: + opCode = isFloatingPoint ? SpvOpFOrdGreaterThanEqual + : isSigned ? SpvOpSGreaterThanEqual : SpvOpUGreaterThanEqual; + break; + case kIROp_Greater: + opCode = isFloatingPoint ? SpvOpFOrdGreaterThan + : isSigned ? SpvOpSGreaterThan : SpvOpUGreaterThan; + break; + case kIROp_Neg: + opCode = isFloatingPoint ? SpvOpFNegate : SpvOpSNegate; + break; + case kIROp_And: + opCode = SpvOpLogicalAnd; + break; + case kIROp_Or: + opCode = SpvOpLogicalOr; + break; + case kIROp_Not: + opCode = SpvOpLogicalNot; + break; + case kIROp_BitAnd: + opCode = SpvOpBitwiseAnd; + break; + case kIROp_BitOr: + opCode = SpvOpBitwiseOr; + break; + case kIROp_BitXor: + opCode = SpvOpBitwiseXor; + break; + case kIROp_BitNot: + opCode = SpvOpBitReverse; + break; + case kIROp_Rsh: + opCode = isSigned ? SpvOpShiftRightArithmetic : SpvOpShiftRightLogical; + break; + case kIROp_Lsh: + opCode = SpvOpShiftLeftLogical; + break; + default: + SLANG_ASSERT(!"unknown arithmetic opcode"); + break; + } + return emitInst(parent, inst, opCode, inst->getDataType(), kResultID, OperandsOf(inst)); + } + + OrderedHashSet<SpvCapability> m_capabilities; + + void requireSPIRVCapability(SpvCapability capability) + { + if (m_capabilities.Add(capability)) + { + emitInst( + getSection(SpvLogicalSectionID::Capabilities), + nullptr, + SpvOpCapability, + capability); + } + } + + void handleRequiredCapabilitiesImpl(IRInst* inst) + { + // TODO: declare required SPV capabilities. + + for (auto decoration : inst->getDecorations()) + { + switch (decoration->getOp()) + { + default: + break; + + case kIROp_RequireGLSLExtensionDecoration: + { + break; + } + case kIROp_RequireGLSLVersionDecoration: + { + break; + } + case kIROp_RequireSPIRVVersionDecoration: + { + break; + } + } + } + } + + SPIRVEmitContext(IRModule* module, TargetRequest* target, DiagnosticSink* sink) + : SPIRVEmitSharedContext(module, target) + , m_irModule(module) + , m_sink(sink) + , m_memoryArena(2048) { } }; SlangResult emitSPIRVFromIR( BackEndCompileRequest* compileRequest, + TargetRequest* targetRequest, IRModule* irModule, const List<IRFunc*>& irEntryPoints, List<uint8_t>& spirvOut) { - SLANG_UNUSED(compileRequest); - spirvOut.clear(); - SPIRVEmitContext context(irModule); + SPIRVEmitContext context(irModule, targetRequest, compileRequest->getSink()); + legalizeIRForSPIRV(&context, irModule, compileRequest->getSink()); context.emitFrontMatter(); for (auto irEntryPoint : irEntryPoints) diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 352a27746..3da19cef1 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -916,6 +916,7 @@ SlangResult emitEntryPointsSourceFromIR( SlangResult emitSPIRVFromIR( BackEndCompileRequest* compileRequest, + TargetRequest* targetRequest, IRModule* irModule, const List<IRFunc*>& irEntryPoints, List<uint8_t>& spirvOut); @@ -947,11 +948,7 @@ SlangResult emitSPIRVForEntryPointsDirectly( auto irModule = linkedIR.module; auto irEntryPoints = linkedIR.entryPoints; - emitSPIRVFromIR( - compileRequest, - irModule, - irEntryPoints, - spirvOut); + emitSPIRVFromIR(compileRequest, targetRequest, irModule, irEntryPoints, spirvOut); return SLANG_OK; } diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 267866b1b..25313d2f5 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -1904,6 +1904,7 @@ struct IRBuilder IRInOutType* getInOutType(IRType* valueType); IRRefType* getRefType(IRType* valueType); IRPtrTypeBase* getPtrType(IROp op, IRType* valueType); + IRPtrType* getPtrType(IROp op, IRType* valueType, IRIntegerValue addressSpace); IRArrayTypeBase* getArrayTypeBase( IROp op, @@ -2734,6 +2735,14 @@ IRTargetSpecificDecoration* findBestTargetDecoration( IRInst* val, CapabilityAtom targetCapabilityAtom); +inline IRTargetIntrinsicDecoration* findBestTargetIntrinsicDecoration( + IRInst* inInst, + CapabilitySet const& targetCaps) +{ + return as<IRTargetIntrinsicDecoration>(findBestTargetDecoration(inInst, targetCaps)); +} + + } #endif diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp new file mode 100644 index 000000000..f7fc53bdb --- /dev/null +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -0,0 +1,258 @@ +// slang-ir-spirv-legalize.cpp +#include "slang-ir-spirv-legalize.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-emit-base.h" +#include "slang-glsl-extension-tracker.h" + +namespace Slang +{ + +// +// Legalization of IR for direct SPIRV emit. +// + +struct StorageClassPropagationContext : public SourceEmitterBase +{ + SPIRVEmitSharedContext* m_sharedContext; + + IRModule* m_module; + // We will use a single work list of instructions that need + // to be considered for specialization or simplification, + // whether generic, existential, etc. + // + OrderedHashSet<IRInst*> workList; + + void addToWorkList(IRInst* inst) + { + if (workList.Add(inst)) + { + addUsersToWorkList(inst); + } + } + + void addUsersToWorkList(IRInst* inst) + { + for (auto use = inst->firstUse; use; use = use->nextUse) + { + auto user = use->getUser(); + + addToWorkList(user); + } + } + + StorageClassPropagationContext(SPIRVEmitSharedContext* sharedContext, IRModule* module) + : m_sharedContext(sharedContext), m_module(module) + { + } + + void processGlobalParam(IRGlobalParam* inst) { processGlobalVar(inst); } + + void processGlobalVar(IRInst* inst) + { + auto oldPtrType = as<IRPtrTypeBase>(inst->getDataType()); + if (!oldPtrType) + return; + + // If the pointer type is already qualified with address spaces (such as + // lowered pointer type from a `HLSLStructuredBufferType`), make no + // further modifications. + if (oldPtrType->hasAddressSpace()) + { + addUsersToWorkList(inst); + return; + } + + auto varLayout = getVarLayout(inst); + if (!varLayout) + return; + + SpvStorageClass storageClass = SpvStorageClassPrivate; + for (auto rr : varLayout->getOffsetAttrs()) + { + switch (rr->getResourceKind()) + { + case LayoutResourceKind::Uniform: + case LayoutResourceKind::ShaderResource: + case LayoutResourceKind::DescriptorTableSlot: + storageClass = SpvStorageClassUniform; + break; + case LayoutResourceKind::VaryingInput: + storageClass = SpvStorageClassInput; + break; + case LayoutResourceKind::VaryingOutput: + storageClass = SpvStorageClassOutput; + break; + case LayoutResourceKind::UnorderedAccess: + storageClass = SpvStorageClassStorageBuffer; + break; + case LayoutResourceKind::PushConstantBuffer: + storageClass = SpvStorageClassPushConstant; + break; + default: + break; + } + } + auto rate = inst->getRate(); + if (as<IRGroupSharedRate>(rate)) + { + storageClass = SpvStorageClassWorkgroup; + } + IRBuilder builder; + builder.sharedBuilder = &m_sharedContext->m_sharedIRBuilder; + builder.setInsertBefore(inst); + auto newPtrType = + builder.getPtrType(oldPtrType->getOp(), oldPtrType->getValueType(), storageClass); + inst->setFullType(newPtrType); + addUsersToWorkList(inst); + return; + } + + void processCall(IRCall* inst) + { + auto funcValue = inst->getOperand(0); + if (auto targetIntrinsic = Slang::findBestTargetIntrinsicDecoration( + funcValue, m_sharedContext->m_targetRequest->getTargetCaps())) + { + SpvSnippet* snippet = m_sharedContext->getParsedSpvSnippet(targetIntrinsic); + if (!snippet) + return; + if (snippet->resultStorageClass != SpvStorageClassMax) + { + auto ptrType = as<IRPtrTypeBase>(inst->getDataType()); + if (!ptrType) + return; + IRBuilder builder; + builder.sharedBuilder = &m_sharedContext->m_sharedIRBuilder; + builder.setInsertBefore(inst); + auto qualPtrType = builder.getPtrType( + ptrType->getOp(), ptrType->getValueType(), snippet->resultStorageClass); + List<IRInst*> args; + for (UInt i = 0; i < inst->getArgCount(); i++) + args.add(inst->getArg(i)); + auto newCall = builder.emitCallInst(qualPtrType, funcValue, args); + inst->replaceUsesWith(newCall); + inst->removeAndDeallocate(); + addUsersToWorkList(newCall); + } + } + } + + void processGetElementPtr(IRGetElementPtr* inst) + { + if (auto ptrType = as<IRPtrTypeBase>(inst->getBase()->getDataType())) + { + if (!ptrType->hasAddressSpace()) + return; + auto oldResultType = as<IRPtrTypeBase>(inst->getDataType()); + if (oldResultType->getAddressSpace() != ptrType->getAddressSpace()) + { + IRBuilder builder; + builder.sharedBuilder = &m_sharedContext->m_sharedIRBuilder; + builder.setInsertBefore(inst); + auto newPtrType = builder.getPtrType( + oldResultType->getOp(), + oldResultType->getValueType(), + ptrType->getAddressSpace()); + auto newInst = + builder.emitElementAddress(newPtrType, inst->getBase(), inst->getIndex()); + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + addUsersToWorkList(newInst); + } + } + } + + void processFieldAddress(IRFieldAddress* inst) + { + if (auto ptrType = as<IRPtrTypeBase>(inst->getBase()->getDataType())) + { + if (!ptrType->hasAddressSpace()) + return; + auto oldResultType = as<IRPtrTypeBase>(inst->getDataType()); + if (oldResultType->getAddressSpace() != ptrType->getAddressSpace()) + { + IRBuilder builder; + builder.sharedBuilder = &m_sharedContext->m_sharedIRBuilder; + builder.setInsertBefore(inst); + auto newPtrType = builder.getPtrType( + oldResultType->getOp(), + oldResultType->getValueType(), + ptrType->getAddressSpace()); + auto newInst = + builder.emitFieldAddress(newPtrType, inst->getBase(), inst->getField()); + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + addUsersToWorkList(newInst); + } + } + } + + void processStructuredBufferType(IRHLSLStructuredBufferTypeBase* inst) + { + IRBuilder builder; + builder.sharedBuilder = &m_sharedContext->m_sharedIRBuilder; + builder.setInsertBefore(inst); + auto arrayType = builder.getUnsizedArrayType(inst->getElementType()); + auto ptrType = builder.getPtrType(kIROp_PtrType, arrayType, SpvStorageClassStorageBuffer); + inst->replaceUsesWith(ptrType); + inst->removeAndDeallocate(); + addUsersToWorkList(ptrType); + } + + void propagate() + { + addToWorkList(m_module->getModuleInst()); + while (workList.Count() != 0) + { + IRInst* inst = workList.getLast(); + workList.removeLast(); + switch (inst->getOp()) + { + case kIROp_GlobalParam: + processGlobalParam(as<IRGlobalParam>(inst)); + break; + case kIROp_GlobalVar: + processGlobalVar(as<IRGlobalVar>(inst)); + break; + case kIROp_Call: + processCall(as<IRCall>(inst)); + break; + case kIROp_getElementPtr: + processGetElementPtr(as<IRGetElementPtr>(inst)); + break; + case kIROp_FieldAddress: + processFieldAddress(as<IRFieldAddress>(inst)); + break; + case kIROp_HLSLStructuredBufferType: + case kIROp_HLSLRWStructuredBufferType: + processStructuredBufferType(as<IRHLSLStructuredBufferTypeBase>(inst)); + break; + default: + for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + addToWorkList(child); + } + break; + } + } + } +}; + +void propagateStorageClass(SPIRVEmitSharedContext* sharedContext, IRModule* module) +{ + StorageClassPropagationContext context(sharedContext, module); + context.propagate(); +} + +void legalizeIRForSPIRV( + SPIRVEmitSharedContext* context, + IRModule* module, + DiagnosticSink* sink) +{ + SLANG_UNUSED(sink); + propagateStorageClass(context, module); +} + +} // namespace Slang diff --git a/source/slang/slang-ir-spirv-legalize.h b/source/slang/slang-ir-spirv-legalize.h new file mode 100644 index 000000000..bf43430d8 --- /dev/null +++ b/source/slang/slang-ir-spirv-legalize.h @@ -0,0 +1,45 @@ +// slang-ir-spirv-legalize.h +#pragma once +#include "../core/slang-basic.h" +#include "slang-ir-spirv-snippet.h" +#include "slang-ir-insts.h" + +namespace Slang +{ + +class DiagnosticSink; + +struct IRFunc; +struct IRModule; +class TargetRequest; + +struct SPIRVEmitSharedContext +{ + SharedIRBuilder m_sharedIRBuilder; + Dictionary<IRTargetIntrinsicDecoration*, RefPtr<SpvSnippet>> m_parsedSpvSnippets; + TargetRequest* m_targetRequest; + + SPIRVEmitSharedContext(IRModule* module, TargetRequest* target) + : m_sharedIRBuilder(module) + , m_targetRequest(target) + {} + + SpvSnippet* getParsedSpvSnippet(IRTargetIntrinsicDecoration* intrinsic) + { + RefPtr<SpvSnippet> snippet; + if (m_parsedSpvSnippets.TryGetValue(intrinsic, snippet)) + { + return snippet.Ptr(); + } + snippet = SpvSnippet::parse(intrinsic->getDefinition()); + m_parsedSpvSnippets[intrinsic] = snippet; + return snippet; + } +}; + +void legalizeIRForSPIRV( + SPIRVEmitSharedContext* context, + IRModule* module, + DiagnosticSink* sink); + +} diff --git a/source/slang/slang-ir-spirv-snippet.cpp b/source/slang/slang-ir-spirv-snippet.cpp new file mode 100644 index 000000000..4083f100d --- /dev/null +++ b/source/slang/slang-ir-spirv-snippet.cpp @@ -0,0 +1,124 @@ +// slang-ir-spirv-snippet.cpp + +#include"slang-ir-spirv-snippet.h" +#include "../core/slang-token-reader.h" + +namespace Slang +{ +static SpvStorageClass translateStorageClass(String name) +{ + if (name == "Uniform") + { + return SpvStorageClassUniform; + } + else if (name == "StorageBuffer") + { + return SpvStorageClassStorageBuffer; + } + return (SpvStorageClass)-1; +} + +RefPtr<SpvSnippet> SpvSnippet::parse(UnownedStringSlice definition) +{ + RefPtr<SpvSnippet> snippet = new SpvSnippet(); + try + { + Dictionary<String, int> mapInstNameToIndex; + Slang::Misc::TokenReader tokenReader(definition); + // A leading "*" at the beginning of the snip modifies $resultType with + // a storage class. + if (tokenReader.AdvanceIf("*")) + { + auto storageToken = tokenReader.ReadWord(); + snippet->resultStorageClass = translateStorageClass(storageToken); + + } + while (!tokenReader.IsEnd()) + { + SpvSnippet::ASMInst inst; + if (tokenReader.AdvanceIf("%")) + { + String instName = tokenReader.ReadToken().Content; + mapInstNameToIndex[instName] = (int)snippet->instructions.getCount(); + tokenReader.Read(Slang::Misc::TokenType::OpAssign); + } + inst.opCode = (SpvWord)tokenReader.ReadInt(); + bool insideOperandList = true; + while (insideOperandList) + { + ASMOperand operand = {ASMOperandType::SpvWord, 0}; + switch (tokenReader.NextToken().Type) + { + case Slang::Misc::TokenType::Semicolon: + insideOperandList = false; + tokenReader.ReadToken(); + break; + case Slang::Misc::TokenType::IntLiteral: + operand.type = SpvSnippet::ASMOperandType::SpvWord; + operand.content = tokenReader.ReadInt(); + inst.operands.add(operand); + break; + case Slang::Misc::TokenType::OpMod: + { + operand.type = SpvSnippet::ASMOperandType::InstReference; + auto refName = tokenReader.ReadToken().Content; + if (!mapInstNameToIndex.TryGetValue(refName, operand.content)) + { + SLANG_ASSERT(!"Invalid SPV ASM: referenced inst is not defined."); + } + inst.operands.add(operand); + } + break; + case Slang::Misc::TokenType::Identifier: + { + auto identifier = tokenReader.ReadToken().Content; + if (identifier.startsWith("_")) + { + operand.type = SpvSnippet::ASMOperandType::ObjectReference; + operand.content = + StringToInt(identifier.subString(1, identifier.getLength() - 1)); + inst.operands.add(operand); + } + else if (identifier == "resultType") + { + operand.type = SpvSnippet::ASMOperandType::ResultTypeId; + operand.content = -1; + if (tokenReader.AdvanceIf("*")) + { + // A "*" at operand qualifies the use of `resultType` with + // a storage class, but does not modify `resultType` itself. + auto storageClass = tokenReader.ReadWord(); + auto spvStorageClass = translateStorageClass(storageClass); + operand.content = spvStorageClass; + snippet->usedResultTypeStorageClasses.add(spvStorageClass); + } + inst.operands.add(operand); + } + else if (identifier == "resultId") + { + operand.type = SpvSnippet::ASMOperandType::ResultId; + inst.operands.add(operand); + } + else + { + SLANG_ASSERT(!"Invalid SPV ASM operand."); + } + } + break; + default: + insideOperandList = false; + break; + } + } + snippet->instructions.add(inst); + } + } + catch (const Slang::Misc::TextFormatException&) + { + SLANG_ASSERT(!"Invalid ASM format."); + } + return snippet; +} + + +} diff --git a/source/slang/slang-ir-spirv-snippet.h b/source/slang/slang-ir-spirv-snippet.h new file mode 100644 index 000000000..74a9b8cd7 --- /dev/null +++ b/source/slang/slang-ir-spirv-snippet.h @@ -0,0 +1,61 @@ +// slang-ir-spirv-legalize.h +#pragma once +#include "../core/slang-basic.h" +#include "spirv/unified1/spirv.h" + +namespace Slang +{ +// +// [2.2: Terms] +// +// > Word: 32 bits. +// +// Despite the importance to SPIR-V, the `spirv.h` header doesn't +// define a type for words, so we'll do it here. + +/// A SPIR-V word. +typedef uint32_t SpvWord; + +/// Represents a parsed Spv ASM from intrinsic definition. +struct SpvSnippet : public RefObject +{ + enum class ASMOperandType + { + // Plain SpvWord to inline without modifications. + SpvWord, + // Represents the result type of the intrinsic. + ResultTypeId, + // Represents the result Id of the ASM inst. + ResultId, + // Represents a reference to an intrinsic argument (e.g. `_1`). + ObjectReference, + // Represents a reference to an ASM inst (e.g. `%t`). + InstReference, + }; + + struct ASMOperand + { + ASMOperandType type; + + // The value of the spv word when type is `SpvWord`, or + // the reference name when type is `ObjectReference` + // (e.g. an argument reference (_1) has `content` == 1). + int content; + }; + + struct ASMInst + { + SpvWord opCode; + List<ASMOperand> operands; + }; + + List<ASMInst> instructions; + List<SpvStorageClass> usedResultTypeStorageClasses; + + SpvStorageClass resultStorageClass = SpvStorageClassMax; + + static RefPtr<SpvSnippet> parse(UnownedStringSlice definition); +}; + + +} diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index fe60fb480..60aaafa83 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2604,6 +2604,12 @@ namespace Slang operands); } + IRPtrType* IRBuilder::getPtrType(IROp op, IRType* valueType, IRIntegerValue addressSpace) + { + IRInst* operands[] = {valueType, getIntValue(getIntType(), addressSpace)}; + return (IRPtrType*)getType(op, 2, operands); + } + IRArrayTypeBase* IRBuilder::getArrayTypeBase( IROp op, IRType* elementType, diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 9eb03c269..7542a883a 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1116,6 +1116,13 @@ struct IRPtrTypeBase : IRType { IRType* getValueType() { return (IRType*)getOperand(0); } + bool hasAddressSpace() { return getOperandCount() > 1; } + + IRIntegerValue getAddressSpace() + { + return getOperandCount() > 1 ? static_cast<IRIntLit*>(getOperand(1))->getValue() : -1; + } + IR_PARENT_ISA(PtrTypeBase) }; diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index ef92558cc..913be346e 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -1100,6 +1100,11 @@ void TargetRequest::addCapability(CapabilityAtom capability) cookedCapabilities = CapabilitySet::makeEmpty(); } +void TargetRequest::setDirectSPIRVEmitMode() +{ + m_emitSPIRVDirectly = true; + cookedCapabilities.makeEmpty(); +} CapabilitySet TargetRequest::getTargetCaps() { @@ -1131,9 +1136,18 @@ CapabilitySet TargetRequest::getTargetCaps() case CodeGenTarget::GLSL: case CodeGenTarget::GLSL_Vulkan: case CodeGenTarget::GLSL_Vulkan_OneDesc: + atoms.add(CapabilityAtom::GLSL); + break; case CodeGenTarget::SPIRV: case CodeGenTarget::SPIRVAssembly: - atoms.add(CapabilityAtom::GLSL); + if (m_emitSPIRVDirectly) + { + atoms.add(CapabilityAtom::SPIRV_DIRECT); + } + else + { + atoms.add(CapabilityAtom::GLSL); + } break; case CodeGenTarget::HLSL: diff --git a/tests/spirv/direct-spirv-compute-simple.slang b/tests/spirv/direct-spirv-compute-simple.slang new file mode 100644 index 000000000..39b9074ed --- /dev/null +++ b/tests/spirv/direct-spirv-compute-simple.slang @@ -0,0 +1,23 @@ +// direct-spirv-compute-simple.slang + +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -xslang -emit-spirv-directly + +// Test runinng a shader generated from direct SPIR-V emit. + +//TEST_INPUT:set resultBuffer = out ubuffer(data=[0 0 0 0], stride=4) +RWStructuredBuffer<uint> resultBuffer; + +[numthreads(4,1,1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint threadId = dispatchThreadID.x; + uint result = threadId + 1; + result = result - 1; + result = result * 2; + result = result / 2; + result = result % 3; + result = (result ^ 7); + result = (result & 7); + result = (result | 8); + resultBuffer[threadId] = result; +} diff --git a/tests/spirv/direct-spirv-compute-simple.slang.expected.txt b/tests/spirv/direct-spirv-compute-simple.slang.expected.txt new file mode 100644 index 000000000..4fc6bca7a --- /dev/null +++ b/tests/spirv/direct-spirv-compute-simple.slang.expected.txt @@ -0,0 +1,4 @@ +F +E +D +F
\ No newline at end of file diff --git a/tests/spirv/direct-spirv-control-flow-2.slang b/tests/spirv/direct-spirv-control-flow-2.slang new file mode 100644 index 000000000..cc908100e --- /dev/null +++ b/tests/spirv/direct-spirv-control-flow-2.slang @@ -0,0 +1,47 @@ +// direct-spirv-control-flow-2.slang + +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -output-using-type -xslang -emit-spirv-directly + +// Test direct SPIR-V emit on control flows. + +//TEST_INPUT:set resultBuffer = out ubuffer(data=[0 0 0 0], stride=4) +RWStructuredBuffer<uint> resultBuffer; + +uint test(uint p) +{ + int result = 0; + for (int i = 0; i < 5; i++) + { + result += i*2; + } + switch (p) + { + case 0: + result = result - 1; + break; + case 1: + result = result + 1; + break; + default: + result = result * 2; + break; + } + if (p > 2) + { + switch (p) + { + case 3: + result++; + break; + } + } + return result; +} + +[numthreads(4,1,1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint threadId = dispatchThreadID.x; + uint result = test(threadId); + resultBuffer[threadId] = result; +} diff --git a/tests/spirv/direct-spirv-control-flow-2.slang.expected.txt b/tests/spirv/direct-spirv-control-flow-2.slang.expected.txt new file mode 100644 index 000000000..36929d66f --- /dev/null +++ b/tests/spirv/direct-spirv-control-flow-2.slang.expected.txt @@ -0,0 +1,5 @@ +type: uint32_t +19 +21 +40 +41 diff --git a/tests/spirv/direct-spirv-control-flow.slang b/tests/spirv/direct-spirv-control-flow.slang new file mode 100644 index 000000000..9efddeb12 --- /dev/null +++ b/tests/spirv/direct-spirv-control-flow.slang @@ -0,0 +1,30 @@ +// direct-spirv-control-flow.slang + +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -xslang -emit-spirv-directly + +// Test direct SPIRV emit on control fl. + +//TEST_INPUT:set resultBuffer = out ubuffer(data=[0 0 0 0], stride=4) +RWStructuredBuffer<uint> resultBuffer; + +uint test(uint p) +{ + int result = 0; + if (p == 0) + { + result = 5; + } + else + { + result = 6; + } + return result; +} + +[numthreads(4,1,1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint threadId = dispatchThreadID.x; + uint result = test(threadId); + resultBuffer[threadId] = result; +} diff --git a/tests/spirv/direct-spirv-control-flow.slang.expected.txt b/tests/spirv/direct-spirv-control-flow.slang.expected.txt new file mode 100644 index 000000000..c0bcc1c4a --- /dev/null +++ b/tests/spirv/direct-spirv-control-flow.slang.expected.txt @@ -0,0 +1,4 @@ +5 +6 +6 +6 diff --git a/tools/gfx/vulkan/render-vk.cpp b/tools/gfx/vulkan/render-vk.cpp index 592cbaac1..88770dbb9 100644 --- a/tools/gfx/vulkan/render-vk.cpp +++ b/tools/gfx/vulkan/render-vk.cpp @@ -5573,6 +5573,7 @@ Result VKDevice::initVulkanInstanceAndDevice(bool useValidationLayer) extendedFeatures.bufferDeviceAddressFeatures.pNext = (void*)deviceCreateInfo.pNext; deviceCreateInfo.pNext = &extendedFeatures.bufferDeviceAddressFeatures; deviceExtensions.add(VK_KHR_BUFFER_DEVICE_ADDRESS_EXTENSION_NAME); + m_features.add("buffer-device-address"); } @@ -5606,6 +5607,7 @@ Result VKDevice::initVulkanInstanceAndDevice(bool useValidationLayer) deviceCreateInfo.enabledExtensionCount = uint32_t(deviceExtensions.getCount()); deviceCreateInfo.ppEnabledExtensionNames = deviceExtensions.getBuffer(); + if (m_api.vkCreateDevice(m_api.m_physicalDevice, &deviceCreateInfo, nullptr, &m_device) != VK_SUCCESS) return SLANG_FAIL; SLANG_RETURN_ON_FAIL(m_api.initDeviceProcs(m_device)); diff --git a/tools/gfx/vulkan/vk-api.h b/tools/gfx/vulkan/vk-api.h index 746648470..0bb3339fb 100644 --- a/tools/gfx/vulkan/vk-api.h +++ b/tools/gfx/vulkan/vk-api.h @@ -226,7 +226,7 @@ struct VulkanExtendedFeatureProperties VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_BUFFER_DEVICE_ADDRESS_FEATURES}; // Inline uniform block features VkPhysicalDeviceInlineUniformBlockFeaturesEXT inlineUniformBlockFeatures = { - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_BUFFER_DEVICE_ADDRESS_FEATURES}; + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_INLINE_UNIFORM_BLOCK_FEATURES_EXT}; }; struct VulkanApi |
