summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2021-08-12 13:14:15 -0700
committerGitHub <noreply@github.com>2021-08-12 13:14:15 -0700
commit6406523511037987d8b8ab881aea41389afd57eb (patch)
tree79f24b6cba377340c2f4d3dcf9fed78fc586f3e0
parent389d21d982da34815b65b10cae63088c397eecc8 (diff)
Further implementation of SPIRV direct emit. (#1920)
* Further implementation of SPIRV direct emit. This change implements: - Struct, Vector, Matrix and Unsized Array types. - Basic arithmetic opcodes, vector construct, swizzle etc. - getElementPtr, getElement, fieldAddress, extractField. - SPIRV target intrinsics with SPIRV asm code in stdlib. - RWStructuredBuffer and StructuredBuffer. - Pointer storage class propagation. - Control flow. * Fix.
-rw-r--r--build/visual-studio/slang/slang.vcxproj6
-rw-r--r--build/visual-studio/slang/slang.vcxproj.filters18
-rw-r--r--source/core/slang-token-reader.h2
-rw-r--r--source/slang/hlsl.meta.slang6
-rw-r--r--source/slang/slang-capability-defs.h1
-rw-r--r--source/slang/slang-compiler.cpp1
-rwxr-xr-xsource/slang/slang-compiler.h3
-rw-r--r--source/slang/slang-emit-base.cpp55
-rw-r--r--source/slang/slang-emit-base.h29
-rw-r--r--source/slang/slang-emit-c-like.cpp49
-rw-r--r--source/slang/slang-emit-c-like.h14
-rw-r--r--source/slang/slang-emit-spirv.cpp1261
-rw-r--r--source/slang/slang-emit.cpp7
-rw-r--r--source/slang/slang-ir-insts.h9
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp258
-rw-r--r--source/slang/slang-ir-spirv-legalize.h45
-rw-r--r--source/slang/slang-ir-spirv-snippet.cpp124
-rw-r--r--source/slang/slang-ir-spirv-snippet.h61
-rw-r--r--source/slang/slang-ir.cpp6
-rw-r--r--source/slang/slang-ir.h7
-rw-r--r--source/slang/slang.cpp16
-rw-r--r--tests/spirv/direct-spirv-compute-simple.slang23
-rw-r--r--tests/spirv/direct-spirv-compute-simple.slang.expected.txt4
-rw-r--r--tests/spirv/direct-spirv-control-flow-2.slang47
-rw-r--r--tests/spirv/direct-spirv-control-flow-2.slang.expected.txt5
-rw-r--r--tests/spirv/direct-spirv-control-flow.slang30
-rw-r--r--tests/spirv/direct-spirv-control-flow.slang.expected.txt4
-rw-r--r--tools/gfx/vulkan/render-vk.cpp2
-rw-r--r--tools/gfx/vulkan/vk-api.h2
29 files changed, 1989 insertions, 106 deletions
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj
index f175d6a31..0ccd5ed70 100644
--- a/build/visual-studio/slang/slang.vcxproj
+++ b/build/visual-studio/slang/slang.vcxproj
@@ -209,6 +209,7 @@
<ClInclude Include="..\..\..\source\slang\slang-diagnostics.h" />
<ClInclude Include="..\..\..\source\slang\slang-doc-extractor.h" />
<ClInclude Include="..\..\..\source\slang\slang-doc-markdown-writer.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-emit-base.h" />
<ClInclude Include="..\..\..\source\slang\slang-emit-c-like.h" />
<ClInclude Include="..\..\..\source\slang\slang-emit-cpp.h" />
<ClInclude Include="..\..\..\source\slang\slang-emit-cuda.h" />
@@ -263,6 +264,8 @@
<ClInclude Include="..\..\..\source\slang\slang-ir-specialize-function-call.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-specialize-resources.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-specialize.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-spirv-legalize.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-spirv-snippet.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-ssa.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-string-hash.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-strip-witness-tables.h" />
@@ -335,6 +338,7 @@
<ClCompile Include="..\..\..\source\slang\slang-diagnostics.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-doc-extractor.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-doc-markdown-writer.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-emit-base.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-emit-c-like.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-emit-cpp.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-emit-cuda.cpp" />
@@ -389,6 +393,8 @@
<ClCompile Include="..\..\..\source\slang\slang-ir-specialize-function-call.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-specialize-resources.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-specialize.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-spirv-legalize.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-spirv-snippet.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-ssa.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-string-hash.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-strip-witness-tables.cpp" />
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters
index 1697a385c..a7affb00a 100644
--- a/build/visual-studio/slang/slang.vcxproj.filters
+++ b/build/visual-studio/slang/slang.vcxproj.filters
@@ -78,6 +78,9 @@
<ClInclude Include="..\..\..\source\slang\slang-doc-markdown-writer.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-emit-base.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-emit-c-like.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -240,6 +243,12 @@
<ClInclude Include="..\..\..\source\slang\slang-ir-specialize.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-spirv-legalize.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-spirv-snippet.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-ssa.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -452,6 +461,9 @@
<ClCompile Include="..\..\..\source\slang\slang-doc-markdown-writer.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-emit-base.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-emit-c-like.cpp">
<Filter>Source Files</Filter>
</ClCompile>
@@ -614,6 +626,12 @@
<ClCompile Include="..\..\..\source\slang\slang-ir-specialize.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-spirv-legalize.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-spirv-snippet.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-ssa.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/source/core/slang-token-reader.h b/source/core/slang-token-reader.h
index 0d59eea76..26539732c 100644
--- a/source/core/slang-token-reader.h
+++ b/source/core/slang-token-reader.h
@@ -73,7 +73,7 @@ namespace Misc {
TokenType Type = TokenType::Unknown;
String Content;
CodePosition Position;
- TokenFlags flags;
+ TokenFlags flags = 0;
Token() = default;
Token(TokenType type, const String & content, int line, int col, int pos, String fileName, TokenFlags flags = 0)
: flags(flags)
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index bb573c2b2..dd4f95cf5 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -199,13 +199,15 @@ struct StructuredBuffer
out uint numStructs,
out uint stride);
- __target_intrinsic(glsl, "$0._data[$1]")
+ __target_intrinsic(glsl, "$0._data[$1]")
+ __target_intrinsic(spirv_direct, "%addr = 65 resultType*StorageBuffer resultId _0 _1; 61 resultType resultId %addr;")
T Load(int location);
T Load(int location, out uint status);
__subscript(uint index) -> T
{
__target_intrinsic(glsl, "$0._data[$1]")
+ __target_intrinsic(spirv_direct, "%addr = 65 resultType*StorageBuffer resultId _0 _1; 61 resultType resultId %addr;")
get;
};
};
@@ -629,12 +631,14 @@ struct $(item.name)
uint IncrementCounter();
__target_intrinsic(glsl, "$0._data[$1]")
+ __target_intrinsic(spirv_direct, "%addr = 65 resultType*StorageBuffer resultId _0 _1; 61 resultType resultId %addr;")
T Load(int location);
T Load(int location, out uint status);
__subscript(uint index) -> T
{
__target_intrinsic(glsl, "$0._data[$1]")
+ __target_intrinsic(spirv_direct, "*StorageBuffer 65 resultType resultId _0 _1")
ref;
}
};
diff --git a/source/slang/slang-capability-defs.h b/source/slang/slang-capability-defs.h
index f66add15b..fc60f4dfa 100644
--- a/source/slang/slang-capability-defs.h
+++ b/source/slang/slang-capability-defs.h
@@ -55,6 +55,7 @@ SLANG_CAPABILITY_ATOM0(GLSL, glsl, Concrete,TargetFormat,0)
SLANG_CAPABILITY_ATOM0(C, c, Concrete,TargetFormat,0)
SLANG_CAPABILITY_ATOM0(CPP, cpp, Concrete,TargetFormat,0)
SLANG_CAPABILITY_ATOM0(CUDA, cuda, Concrete,TargetFormat,0)
+SLANG_CAPABILITY_ATOM0(SPIRV_DIRECT, spirv_direct, Concrete, TargetFormat, 0)
// We have multiple capabilities for the various SPIR-V versions,
// arranged so that they inherit from one another to represent which versions
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp
index d909a190c..028886c7e 100644
--- a/source/slang/slang-compiler.cpp
+++ b/source/slang/slang-compiler.cpp
@@ -1445,6 +1445,7 @@ namespace Slang
if (target == CodeGenTarget::SPIRV && compileRequest->shouldEmitSPIRVDirectly)
{
List<uint8_t> spirv;
+ targetReq->setDirectSPIRVEmitMode();
SLANG_RETURN_ON_FAIL(emitSPIRVForEntryPointsDirectly(compileRequest, entryPointIndices, targetReq, spirv));
auto spirvBlob = ListBlob::moveCreate(spirv);
downstreamResult = new BlobDownstreamCompileResult(DownstreamDiagnostics(), spirvBlob);
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index 01f23918b..b829fd0ee 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -1190,6 +1190,8 @@ namespace Slang
return (targetFlags & SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM) != 0;
}
+ void setDirectSPIRVEmitMode();
+
Linkage* getLinkage() { return linkage; }
CodeGenTarget getTarget() { return format; }
Profile getTargetProfile() { return targetProfile; }
@@ -1217,6 +1219,7 @@ namespace Slang
List<CapabilityAtom> rawCapabilities;
CapabilitySet cookedCapabilities;
LineDirectiveMode lineDirectiveMode = LineDirectiveMode::Default;
+ bool m_emitSPIRVDirectly = false;
};
/// Are we generating code for a D3D API?
diff --git a/source/slang/slang-emit-base.cpp b/source/slang/slang-emit-base.cpp
new file mode 100644
index 000000000..d00b723ab
--- /dev/null
+++ b/source/slang/slang-emit-base.cpp
@@ -0,0 +1,55 @@
+#include "slang-emit-base.h"
+
+namespace Slang
+{
+
+IRInst* SourceEmitterBase::getSpecializedValue(IRSpecialize* specInst)
+{
+ auto base = specInst->getBase();
+
+ // It is possible to have a `specialize(...)` where the first
+ // operand is also a `specialize(...)`, so that we need to
+ // look at what declaration is being specialized at the inner
+ // step to find the one being specialized at the outer step.
+ //
+ while (auto baseSpecialize = as<IRSpecialize>(base))
+ {
+ base = getSpecializedValue(baseSpecialize);
+ }
+
+ auto baseGeneric = as<IRGeneric>(base);
+ if (!baseGeneric)
+ return base;
+
+ auto lastBlock = baseGeneric->getLastBlock();
+ if (!lastBlock)
+ return base;
+
+ auto returnInst = as<IRReturnVal>(lastBlock->getTerminator());
+ if (!returnInst)
+ return base;
+
+ return returnInst->getVal();
+}
+
+void SourceEmitterBase::handleRequiredCapabilities(IRInst* inst)
+{
+ auto decoratedValue = inst;
+ while (auto specInst = as<IRSpecialize>(decoratedValue))
+ {
+ decoratedValue = getSpecializedValue(specInst);
+ }
+
+ handleRequiredCapabilitiesImpl(decoratedValue);
+}
+
+IRVarLayout* SourceEmitterBase::getVarLayout(IRInst* var)
+{
+ auto decoration = var->findDecoration<IRLayoutDecoration>();
+ if (!decoration)
+ return nullptr;
+
+ return as<IRVarLayout>(decoration->getLayout());
+}
+
+}
diff --git a/source/slang/slang-emit-base.h b/source/slang/slang-emit-base.h
new file mode 100644
index 000000000..ffbf56618
--- /dev/null
+++ b/source/slang/slang-emit-base.h
@@ -0,0 +1,29 @@
+// slang-emit-base.h
+#ifndef SLANG_EMIT_BASE_H
+#define SLANG_EMIT_BASE_H
+
+#include "../core/slang-basic.h"
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-restructure.h"
+
+namespace Slang
+{
+
+class SourceEmitterBase : public RefObject
+{
+public:
+ IRInst* getSpecializedValue(IRSpecialize* specInst);
+
+ /// Inspect the capabilities required by `inst` (according to its decorations),
+ /// and ensure that those capabilities have been detected and stored in the
+ /// target-specific extension tracker.
+ void handleRequiredCapabilities(IRInst* inst);
+ virtual void handleRequiredCapabilitiesImpl(IRInst* inst) { SLANG_UNUSED(inst); }
+
+ static IRVarLayout* getVarLayout(IRInst* var);
+};
+
+}
+#endif
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index 84c369d40..f9d71beb9 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -1450,17 +1450,6 @@ void CLikeSourceEmitter::_emitCallArgList(IRCall* inst)
m_writer->emit(")");
}
-void CLikeSourceEmitter::handleRequiredCapabilities(IRInst* inst)
-{
- auto decoratedValue = inst;
- while (auto specInst = as<IRSpecialize>(decoratedValue))
- {
- decoratedValue = getSpecializedValue(specInst);
- }
-
- handleRequiredCapabilitiesImpl(decoratedValue);
-}
-
void CLikeSourceEmitter::emitCallExpr(IRCall* inst, EmitOpInfo outerPrec)
{
auto funcValue = inst->getOperand(0);
@@ -2164,15 +2153,6 @@ void CLikeSourceEmitter::emitSemantics(IRInst* inst)
emitSemanticsImpl(inst);
}
-IRVarLayout* CLikeSourceEmitter::getVarLayout(IRInst* var)
-{
- auto decoration = var->findDecoration<IRLayoutDecoration>();
- if (!decoration)
- return nullptr;
-
- return as<IRVarLayout>(decoration->getLayout());
-}
-
void CLikeSourceEmitter::emitLayoutSemantics(IRInst* inst, char const* uniformSemanticSpelling)
{
emitLayoutSemanticsImpl(inst, uniformSemanticSpelling);
@@ -2781,35 +2761,6 @@ void CLikeSourceEmitter::emitParamTypeImpl(IRType* type, String const& name)
emitType(type, name);
}
-IRInst* CLikeSourceEmitter::getSpecializedValue(IRSpecialize* specInst)
-{
- auto base = specInst->getBase();
-
- // It is possible to have a `specialize(...)` where the first
- // operand is also a `specialize(...)`, so that we need to
- // look at what declaration is being specialized at the inner
- // step to find the one being specialized at the outer step.
- //
- while(auto baseSpecialize = as<IRSpecialize>(base))
- {
- base = getSpecializedValue(baseSpecialize);
- }
-
- auto baseGeneric = as<IRGeneric>(base);
- if (!baseGeneric)
- return base;
-
- auto lastBlock = baseGeneric->getLastBlock();
- if (!lastBlock)
- return base;
-
- auto returnInst = as<IRReturnVal>(lastBlock->getTerminator());
- if (!returnInst)
- return base;
-
- return returnInst->getVal();
-}
-
void CLikeSourceEmitter::emitFuncDecl(IRFunc* func)
{
// We don't want to emit declarations for operations
diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h
index 90db7476c..f699ed255 100644
--- a/source/slang/slang-emit-c-like.h
+++ b/source/slang/slang-emit-c-like.h
@@ -6,6 +6,7 @@
#include "slang-compiler.h"
+#include "slang-emit-base.h"
#include "slang-emit-precedence.h"
#include "slang-emit-source-writer.h"
@@ -16,7 +17,7 @@
namespace Slang
{
-class CLikeSourceEmitter: public RefObject
+class CLikeSourceEmitter: public SourceEmitterBase
{
public:
struct Desc
@@ -292,8 +293,6 @@ public:
void emitSemantics(IRInst* inst);
void emitSemanticsUsingVarLayout(IRVarLayout* varLayout);
- static IRVarLayout* getVarLayout(IRInst* var);
-
void emitLayoutSemantics(IRInst* inst, char const* uniformSemanticSpelling = "register");
// When we are about to traverse an edge from one block to another,
@@ -323,8 +322,6 @@ public:
void emitParamType(IRType* type, String const& name) { emitParamTypeImpl(type, name); }
- IRInst* getSpecializedValue(IRSpecialize* specInst);
-
void emitFuncDecl(IRFunc* func);
IREntryPointLayout* getEntryPointLayout(IRFunc* func);
@@ -453,15 +450,8 @@ public:
virtual bool tryEmitGlobalParamImpl(IRGlobalParam* varDecl, IRType* varType) { SLANG_UNUSED(varDecl); SLANG_UNUSED(varType); return false; }
virtual bool tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) { SLANG_UNUSED(inst); SLANG_UNUSED(inOuterPrec); return false; }
- /// Inspect the capabilities required by `inst` (according to its decorations),
- /// and ensure that those capabilities have been detected and stored in the
- /// target-specific extension tracker.
- void handleRequiredCapabilities(IRInst* inst);
- virtual void handleRequiredCapabilitiesImpl(IRInst* inst) { SLANG_UNUSED(inst); }
-
virtual void emitPostKeywordTypeAttributesImpl(IRInst* inst) { SLANG_UNUSED(inst); }
-
void _emitArrayType(IRArrayType* arrayType, DeclaratorInfo* declarator);
void _emitUnsizedArrayType(IRUnsizedArrayType* arrayType, DeclaratorInfo* declarator);
void _emitType(IRType* type, DeclaratorInfo* declarator);
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index fe039feb0..37fd673ed 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -2,11 +2,14 @@
#include "slang-emit.h"
#include "slang-compiler.h"
+#include "slang-emit-base.h"
+
#include "slang-ir.h"
#include "slang-ir-insts.h"
-
+#include "slang-ir-layout.h"
+#include "slang-ir-spirv-snippet.h"
+#include "slang-ir-spirv-legalize.h"
#include "spirv/unified1/spirv.h"
-
#include "../core/slang-memory-arena.h"
namespace Slang
@@ -36,16 +39,7 @@ namespace Slang
// [2.3: Physical Layout of a SPIR-V Module and Instruction]
//
// > A SPIR-V module is a single linear stream of words.
-//
-// [2.2: Terms]
-//
-// > Word: 32 bits.
-//
-// Despite the importance to SPIR-V, the `spirv.h` header doesn't
-// define a type for words, so we'll do it here.
- /// A SPIR-V word.
-typedef uint32_t SpvWord;
// [2.3: Physical Layout of a SPIR-V Module and Instruction]
//
@@ -268,6 +262,14 @@ void SpvInstParent::dumpTo(List<SpvWord>& ioWords)
}
}
+/// The context for inlining a SPV assembly snippet.
+struct SpvSnippetEmitContext
+{
+ SpvInst* resultType;
+ Dictionary<SpvStorageClass, IRInst*> qualifiedResultTypes;
+ List<SpvWord> argumentIds;
+};
+
// Now that we've defined the intermediate data structures we will
// use to represent SPIR-V code during emission, we will move on
// to defining the main context type that will drive SPIR-V
@@ -275,10 +277,14 @@ void SpvInstParent::dumpTo(List<SpvWord>& ioWords)
/// Context used for translating a Slang IR module to SPIR-V
struct SPIRVEmitContext
+ : public SourceEmitterBase
+ , public SPIRVEmitSharedContext
{
/// The Slang IR module being translated
IRModule* m_irModule;
+ DiagnosticSink* m_sink;
+
// [2.2: Terms]
//
// > <id>: A numerical name; the name used to refer to an object, a type,
@@ -385,12 +391,36 @@ struct SPIRVEmitContext
/// Map a Slang IR instruction to the corresponding SPIR-V instruction
Dictionary<IRInst*, SpvInst*> m_mapIRInstToSpvInst;
+ // Sometimes we need to reserve an ID for an `IRInst` without actually
+ // emitting it. We use `m_mapIRInstToSpvID` to hold all reserved SpvIDs.
+ // Use `getIRInstSpvID` to obtain an SpvID for an `IRInst` if the
+ // `IRInst` may not have been emitted.
+ Dictionary<IRInst*, SpvWord> m_mapIRInstToSpvID;
+
/// Register that `irInst` maps to `spvInst`
void registerInst(IRInst* irInst, SpvInst* spvInst)
{
m_mapIRInstToSpvInst.Add(irInst, spvInst);
}
+ /// Get or reserve a SpvID for an IR value.
+ SpvWord getIRInstSpvID(IRInst* inst)
+ {
+ // If we have already emitted an SpvInst for `inst`, return its ID.
+ SpvInst* spvInst = nullptr;
+ if (m_mapIRInstToSpvInst.TryGetValue(inst, spvInst))
+ return getID(spvInst);
+ // Check if we have reserved an ID for `inst`.
+ SpvWord result = 0;
+ if (m_mapIRInstToSpvID.TryGetValue(inst, result))
+ return result;
+ // Otherwise, reserve a new ID for inst, and register it in `m_mapIRInstToSpvID`.
+ result = m_nextID;
+ ++m_nextID;
+ m_mapIRInstToSpvID[inst] = result;
+ return result;
+ }
+
// When we are emitting an instruction that can produce
// a result, we will allocate an <id> to it so that other
// instructions can refer to it.
@@ -409,6 +439,18 @@ struct SPIRVEmitContext
return id;
}
+ struct VectorTypeKey
+ {
+ BaseType baseType;
+ IRIntegerValue elementCount;
+ HashCode getHashCode() { return combineHash((int)baseType, (HashCode)elementCount); }
+ bool operator==(const VectorTypeKey& other)
+ {
+ return baseType == other.baseType && elementCount == other.elementCount;
+ }
+ };
+ Dictionary<VectorTypeKey, SpvInst*> m_vectorTypes;
+
// We will build up `SpvInst`s in a stateful fashion,
// mostly for convenience. We could in theory compute
// the number of words each instruction needs, then allocate
@@ -467,6 +509,8 @@ struct SPIRVEmitContext
if(irInst)
{
registerInst(irInst, spvInst);
+ // If we have reserved an SpvID for `irInst`, make sure to use it.
+ m_mapIRInstToSpvID.TryGetValue(irInst, spvInst->id);
}
// Set up the scope
@@ -561,9 +605,6 @@ struct SPIRVEmitContext
/// Emit an operand to the current instruction, which references `src` by its <id>
void emitOperand(IRInst* src)
{
- // We first ensure that the `src` instruction has been emitted,
- // and then handle it as for any other <id> operand.
- //
SpvInst* spvSrc = ensureInst(src);
emitOperand(getID(spvSrc));
}
@@ -629,6 +670,25 @@ struct SPIRVEmitContext
emitOperand(getID(m_currentInst));
}
+ void emitOperand(SpvDecoration decoration) { emitOperand((SpvWord)decoration); }
+
+ void emitOperand(SpvBuiltIn builtin) { emitOperand((SpvWord)builtin); }
+ void emitOperand(SpvStorageClass val) { emitOperand((SpvWord)val); }
+
+ Dictionary<IRIntegerValue, SpvInst*> m_spvIntConstants;
+ SpvInst* emitConstant(IRIntegerValue val, IRType* type)
+ {
+ SpvInst* result = nullptr;
+ if (m_spvIntConstants.TryGetValue(val, result))
+ return result;
+ return emitInst(
+ getSection(SpvLogicalSectionID::Constants),
+ nullptr,
+ SpvOpConstant,
+ type,
+ kResultID,
+ (SpvWord)val);
+ }
// As another convenience, there are often cases where
// we will want to emit all of the operands of some
// IR instruction as <id> operands of a SPIR-V
@@ -742,6 +802,16 @@ struct SPIRVEmitContext
return spvInst;
}
+ template<typename OperandEmitFunc>
+ SpvInst* emitInstCustomOperandFunc(SpvInstParent* parent, IRInst* irInst, SpvOp opcode, const OperandEmitFunc& f)
+ {
+ InstConstructScope scopeInst(this, opcode, irInst);
+ SpvInst* spvInst = scopeInst;
+ f();
+ parent->addInst(spvInst);
+ return spvInst;
+ }
+
// Now that we've gotten the core infrastructure out of the way,
// let's start looking at emitting some instructions that make
// up a SPIR-V module.
@@ -826,14 +896,110 @@ struct SPIRVEmitContext
CASE(kIROp_DoubleType, 64);
#undef CASE
-
- // > OpTypeVector
- // > OpTypeMatrix
+ case kIROp_PtrType:
+ case kIROp_RefType:
+ case kIROp_OutType:
+ case kIROp_InOutType:
+ {
+ SpvStorageClass storageClass = SpvStorageClassFunction;
+ auto ptrType = as<IRPtrTypeBase>(inst);
+ if (ptrType->hasAddressSpace())
+ storageClass = (SpvStorageClass)ptrType->getAddressSpace();
+ return emitInst(
+ getSection(SpvLogicalSectionID::Types),
+ inst,
+ SpvOpTypePointer,
+ kResultID,
+ storageClass,
+ inst->getOperand(0));
+ }
+ case kIROp_StructType:
+ {
+ return emitInstCustomOperandFunc(
+ getSection(SpvLogicalSectionID::Types), inst, SpvOpTypeStruct, [&]() {
+ emitOperand(kResultID);
+ for (auto field : static_cast<IRStructType*>(inst)->getFields())
+ {
+ emitOperand(field->getFieldType());
+ // TODO: decorate offset
+ }
+ });
+ }
+ case kIROp_VectorType:
+ {
+ auto vectorType = static_cast<IRVectorType*>(inst);
+ return ensureVectorType(
+ static_cast<IRBasicType*>(vectorType->getElementType())->getBaseType(),
+ static_cast<IRIntLit*>(vectorType->getElementCount())->getValue(),
+ vectorType);
+ }
+ case kIROp_MatrixType:
+ {
+ auto matrixType = static_cast<IRMatrixType*>(inst);
+ auto vectorSpvType = ensureVectorType(
+ static_cast<IRBasicType*>(matrixType->getElementType())->getBaseType(),
+ static_cast<IRIntLit*>(matrixType->getRowCount())->getValue(),
+ nullptr);
+ auto matrixSPVType = emitInst(
+ getSection(SpvLogicalSectionID::Types),
+ inst,
+ SpvOpTypeMatrix,
+ kResultID,
+ vectorSpvType,
+ (SpvWord)static_cast<IRIntLit*>(matrixType->getColumnCount())->getValue());
+ // TODO: properly compute matrix stride.
+ auto columnCount = static_cast<IRIntLit*>(matrixType->getRowCount())->getValue();
+ uint32_t stride = 0;
+ switch (columnCount)
+ {
+ case 1:
+ stride = 4;
+ break;
+ case 2:
+ stride = 8;
+ break;
+ case 3:
+ case 4:
+ stride = 16;
+ break;
+ default:
+ break;
+ }
+ emitInst(
+ getSection(SpvLogicalSectionID::Annotations),
+ nullptr,
+ SpvOpDecorate,
+ matrixSPVType,
+ SpvDecorationRowMajor,
+ SpvDecorationMatrixStride,
+ stride);
+ return matrixSPVType;
+ }
+ case kIROp_UnsizedArrayType:
+ {
+ auto elementType = static_cast<IRUnsizedArrayType*>(inst)->getElementType();
+ auto runtimeArrayType = emitInst(
+ getSection(SpvLogicalSectionID::Types),
+ nullptr,
+ SpvOpTypeRuntimeArray,
+ kResultID,
+ elementType);
+ // TODO: properly decorate stride.
+ IRSizeAndAlignment sizeAndAlignment;
+ getNaturalSizeAndAlignment(this->m_targetRequest, elementType, &sizeAndAlignment);
+ emitInst(
+ getSection(SpvLogicalSectionID::Annotations),
+ nullptr,
+ SpvOpDecorate,
+ runtimeArrayType,
+ SpvDecorationArrayStride,
+ (SpvWord)sizeAndAlignment.getStride());
+ return runtimeArrayType;
+ }
// > OpTypeImage
// > OpTypeSampler
// > OpTypeArray
// > OpTypeRuntimeArray
- // > OpTypeStruct
// > OpTypeOpaque
// > OpTypePointer
@@ -858,6 +1024,15 @@ struct SPIRVEmitContext
//
return emitFunc(as<IRFunc>(inst));
+ case kIROp_BoolLit:
+ case kIROp_IntLit:
+ case kIROp_FloatLit:
+ return emitLit(inst);
+
+ case kIROp_GlobalParam:
+ return emitGlobalParam(as<IRGlobalParam>(inst));
+ case kIROp_GlobalVar:
+ return emitGlobalVar(as<IRGlobalVar>(inst));
// ...
default:
@@ -866,6 +1041,162 @@ struct SPIRVEmitContext
}
}
+ // Ensures an SpvInst for the specified vector type is emitted.
+ // `inst` represents an optional `IRVectorType` inst representing the vector type, if
+ // it is nullptr, this function will create one.
+ SpvInst* ensureVectorType(BaseType baseType, IRIntegerValue elementCount, IRVectorType* inst)
+ {
+ VectorTypeKey key = {baseType, elementCount};
+ SpvInst* result = nullptr;
+ if (m_vectorTypes.TryGetValue(key, result))
+ return result;
+ if (!inst)
+ {
+ IRBuilder builder;
+ builder.sharedBuilder = &m_sharedIRBuilder;
+ builder.setInsertInto(m_irModule->getModuleInst());
+ inst = builder.getVectorType(
+ builder.getBasicType(baseType),
+ builder.getIntValue(builder.getIntType(), elementCount));
+ }
+ result = emitInst(
+ getSection(SpvLogicalSectionID::Types),
+ inst,
+ SpvOpTypeVector,
+ kResultID,
+ inst->getElementType(),
+ (SpvWord)elementCount);
+ m_vectorTypes[key] = result;
+ return result;
+ }
+
+ void emitVarLayout(SpvInst* varInst, IRVarLayout* layout)
+ {
+ for (auto rr : layout->getOffsetAttrs())
+ {
+ UInt index = rr->getOffset();
+ UInt space = rr->getSpace();
+ switch (rr->getResourceKind())
+ {
+ case LayoutResourceKind::Uniform:
+ break;
+
+ case LayoutResourceKind::VaryingInput:
+ emitInst(
+ getSection(SpvLogicalSectionID::Annotations),
+ nullptr,
+ SpvOpDecorate,
+ varInst,
+ SpvDecorationLocation,
+ (SpvWord)index);
+ emitInst(
+ getSection(SpvLogicalSectionID::Annotations),
+ nullptr,
+ SpvOpDecorate,
+ varInst,
+ SpvDecorationIndex,
+ (SpvWord)space);
+ break;
+ case LayoutResourceKind::VaryingOutput:
+ emitInst(
+ getSection(SpvLogicalSectionID::Annotations),
+ nullptr,
+ SpvOpDecorate,
+ varInst,
+ SpvDecorationLocation,
+ (SpvWord)index);
+ if (space)
+ {
+ emitInst(
+ getSection(SpvLogicalSectionID::Annotations),
+ nullptr,
+ SpvOpDecorate,
+ varInst,
+ SpvDecorationIndex,
+ (SpvWord)space);
+ }
+ break;
+
+ case LayoutResourceKind::SpecializationConstant:
+ emitInst(
+ getSection(SpvLogicalSectionID::Annotations),
+ nullptr,
+ SpvOpDecorate,
+ varInst,
+ SpvDecorationSpecId,
+ (SpvWord)index);
+ break;
+
+ case LayoutResourceKind::ConstantBuffer:
+ case LayoutResourceKind::ShaderResource:
+ case LayoutResourceKind::UnorderedAccess:
+ case LayoutResourceKind::SamplerState:
+ case LayoutResourceKind::DescriptorTableSlot:
+ emitInst(
+ getSection(SpvLogicalSectionID::Annotations),
+ nullptr,
+ SpvOpDecorate,
+ varInst,
+ SpvDecorationBinding,
+ (SpvWord)index);
+ if (space)
+ {
+ emitInst(
+ getSection(SpvLogicalSectionID::Annotations),
+ nullptr,
+ SpvOpDecorate,
+ varInst,
+ SpvDecorationDescriptorSet,
+ (SpvWord)space);
+ }
+ break;
+ default:
+ break;
+ }
+ }
+ }
+ /// Emit a global parameter definition.
+ SpvInst* emitGlobalParam(IRGlobalParam* param)
+ {
+ auto layout = getVarLayout(param);
+ auto storageClass = SpvStorageClassUniform;
+ if (auto ptrType = as<IRPtrTypeBase>(param->getDataType()))
+ {
+ if (ptrType->hasAddressSpace())
+ storageClass = (SpvStorageClass)ptrType->getAddressSpace();
+ }
+ auto varInst = emitInst(
+ getSection(SpvLogicalSectionID::GlobalVariables),
+ param,
+ SpvOpVariable,
+ param->getDataType(),
+ kResultID,
+ storageClass);
+ emitVarLayout(varInst, layout);
+ return varInst;
+ }
+
+ /// Emit a global variable definition.
+ SpvInst* emitGlobalVar(IRGlobalVar* globalVar)
+ {
+ auto layout = getVarLayout(globalVar);
+ auto storageClass = SpvStorageClassUniform;
+ if (auto ptrType = as<IRPtrTypeBase>(globalVar->getDataType()))
+ {
+ if (ptrType->hasAddressSpace())
+ storageClass = (SpvStorageClass)ptrType->getAddressSpace();
+ }
+ auto varInst = emitInst(
+ getSection(SpvLogicalSectionID::GlobalVariables),
+ globalVar,
+ SpvOpVariable,
+ globalVar->getDataType(),
+ kResultID,
+ storageClass);
+ emitVarLayout(varInst, layout);
+ return varInst;
+ }
+
/// Emit the given `irFunc` to SPIR-V
SpvInst* emitFunc(IRFunc* irFunc)
{
@@ -951,9 +1282,7 @@ struct SPIRVEmitContext
//
for( auto irParam : irFunc->getParams() )
{
- emitInst(spvFunc, irParam, SpvOpFunctionParameter,
- irParam->getFullType(),
- kResultID);
+ emitParam(spvFunc, irParam);
}
// [3.32.17. Control-Flow Instructions]
@@ -992,11 +1321,13 @@ struct SPIRVEmitContext
// [3.32.17. Control-Flow Instructions]
//
// > OpPhi
- //
- // TODO: We eventually need to emit `OpPhi` instructions corresponding
- // to the parameters of any non-entry block, with operands representing
- // the values passed along incoming edges from the predecessor blocks.
-
+ if (irBlock != irFunc->getFirstBlock())
+ {
+ for (auto irParam : irBlock->getParams())
+ {
+ emitPhi(spvBlock, irParam);
+ }
+ }
for( auto irInst : irBlock->getOrdinaryInsts() )
{
// Any instructions local to the block will be emitted as children
@@ -1036,16 +1367,243 @@ struct SPIRVEmitContext
/// Emit an instruction that is local to the body of the given `parent`.
SpvInst* emitLocalInst(SpvInstParent* parent, IRInst* inst)
{
+ auto getBlockID = [=](IRBlock* block)
+ {
+ SpvInst* spvInst = nullptr;
+ m_mapIRInstToSpvInst.TryGetValue(block, spvInst);
+ SLANG_ASSERT(spvInst);
+ return getID(spvInst);
+ };
switch( inst->getOp() )
{
default:
SLANG_UNIMPLEMENTED_X("unhandled instruction opcode");
break;
+ case kIROp_Specialize:
+ return nullptr;
+ case kIROp_Var:
+ return emitVar(parent, inst);
+ case kIROp_Call:
+ return emitCall(parent, inst);
+ case kIROp_FieldAddress:
+ return emitFieldAddress(parent, as<IRFieldAddress>(inst));
+ case kIROp_FieldExtract:
+ return emitFieldExtract(parent, as<IRFieldExtract>(inst));
+ case kIROp_getElementPtr:
+ return emitGetElementPtr(parent, as<IRGetElementPtr>(inst));
+ case kIROp_getElement:
+ return emitGetElement(parent, as<IRGetElement>(inst));
+ case kIROp_Load:
+ return emitLoad(parent, as<IRLoad>(inst));
+ case kIROp_Store:
+ return emitStore(parent, as<IRStore>(inst));
+ case kIROp_swizzle:
+ return emitSwizzle(parent, as<IRSwizzle>(inst));
+ case kIROp_Construct:
+ return emitConstruct(parent, inst);
+ case kIROp_Add:
+ case kIROp_Sub:
+ case kIROp_Mul:
+ case kIROp_Div:
+ case kIROp_IRem:
+ case kIROp_FRem:
+ case kIROp_Neg:
+ case kIROp_Not:
+ case kIROp_And:
+ case kIROp_Or:
+ case kIROp_BitNot:
+ case kIROp_BitAnd:
+ case kIROp_BitOr:
+ case kIROp_BitXor:
+ case kIROp_Less:
+ case kIROp_Leq:
+ case kIROp_Eql:
+ case kIROp_Neq:
+ case kIROp_Greater:
+ case kIROp_Geq:
+ case kIROp_Rsh:
+ case kIROp_Lsh:
+ return emitArithmetic(parent, inst);
+ case kIROp_ReturnVal:
+ return emitInst(
+ parent, inst, SpvOpReturnValue, as<IRReturnVal>(inst)->getVal());
+ case kIROp_ReturnVoid:
+ return emitInst(parent, inst, SpvOpReturn);
+ case kIROp_discard:
+ return emitInst(parent, inst, SpvOpKill);
+ case kIROp_unconditionalBranch:
+ return emitInst(
+ parent,
+ inst,
+ SpvOpBranch,
+ getBlockID(as<IRUnconditionalBranch>(inst)->getTargetBlock()));
+ case kIROp_loop:
+ {
+ auto loopInst = as<IRLoop>(inst);
+
+ SpvWord loopControl = 0;
+ if (auto loopControlDecoration =
+ loopInst->findDecoration<IRLoopControlDecoration>())
+ {
+ switch (loopControlDecoration->getMode())
+ {
+ case IRLoopControl::kIRLoopControl_Unroll:
+ loopControl = 0x1;
+ break;
+ case IRLoopControl::kIRLoopControl_Loop:
+ loopControl = 0x2;
+ break;
+ default:
+ break;
+ }
+ }
+ emitInst(
+ parent,
+ nullptr,
+ SpvOpLoopMerge,
+ getBlockID(loopInst->getBreakBlock()),
+ getBlockID(loopInst->getContinueBlock()),
+ loopControl);
+
+ return emitInst(parent, inst, SpvOpBranch, loopInst->getTargetBlock());
+ }
+ case kIROp_ifElse:
+ {
+ auto ifelseInst = as<IRIfElse>(inst);
+ auto afterBlockID = getBlockID(ifelseInst->getAfterBlock());
+ emitInst(
+ parent,
+ nullptr,
+ SpvOpSelectionMerge,
+ afterBlockID);
+ auto falseLabel = ifelseInst->getFalseBlock();
+ return emitInst(
+ parent,
+ inst,
+ SpvOpBranchConditional,
+ ifelseInst->getCondition(),
+ ifelseInst->getTrueBlock(),
+ falseLabel ? getID(ensureInst(falseLabel)) : afterBlockID);
+ }
+ case kIROp_Switch:
+ {
+ auto switchInst = as<IRSwitch>(inst);
+ auto mergeBlockID = getBlockID(switchInst->getBreakLabel());
+ emitInst(
+ parent,
+ nullptr,
+ SpvOpSelectionMerge, mergeBlockID);
+ return emitInstCustomOperandFunc(parent, inst, SpvOpSwitch, [&]() {
+ emitOperand(switchInst->getCondition());
+ auto defaultLabel = switchInst->getDefaultLabel();
+ emitOperand(defaultLabel ? getID(ensureInst(defaultLabel)) : mergeBlockID);
+ for (UInt c = 0; c < switchInst->getCaseCount(); c++)
+ {
+ auto value = switchInst->getCaseValue(c);
+ auto intLit = as<IRIntLit>(value);
+ SLANG_ASSERT(intLit);
+ emitOperand((SpvWord)intLit->getValue());
+ auto caseLabel = switchInst->getCaseLabel(c);
+ emitOperand(caseLabel ? getID(ensureInst(caseLabel)) : mergeBlockID);
+ }
+ });
+ }
+ case kIROp_Unreachable:
+ return emitInst(parent, inst, SpvOpUnreachable);
+ case kIROp_conditionalBranch:
+ SLANG_UNEXPECTED("Unstructured branching is not supported by SPIRV.");
+ }
+ }
- // [3.32.17. Control-Flow Instructions]
- //
- // > OpReturn
- case kIROp_ReturnVoid: return emitInst(parent, inst, SpvOpReturn);
+ SpvInst* emitLit(IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_IntLit:
+ {
+ auto value = as<IRIntLit>(inst)->getValue();
+ switch (as<IRBasicType>(inst->getDataType())->getBaseType())
+ {
+ case BaseType::Int64:
+ case BaseType::UInt64:
+ return emitInst(
+ getSection(SpvLogicalSectionID::Constants),
+ inst,
+ SpvOpConstant,
+ inst->getDataType(),
+ kResultID,
+ (SpvWord)(value & 0xFFFFFFFF),
+ (SpvWord)((value >> 32) & 0xFFFFFFFF));
+ default:
+ return emitInst(
+ getSection(SpvLogicalSectionID::Constants),
+ inst,
+ SpvOpConstant,
+ inst->getDataType(),
+ kResultID,
+ (SpvWord)value);
+ }
+ }
+ case kIROp_FloatLit:
+ {
+ auto value = as<IRConstant>(inst)->value.floatVal;
+ switch (as<IRBasicType>(inst->getDataType())->getBaseType())
+ {
+ case BaseType::Half:
+ return emitInst(
+ getSection(SpvLogicalSectionID::Constants),
+ inst,
+ SpvOpConstant,
+ inst->getDataType(),
+ kResultID,
+ (SpvWord)(FloatToHalf((float)value)));
+ case BaseType::Float:
+ return emitInst(
+ getSection(SpvLogicalSectionID::Constants),
+ inst,
+ SpvOpConstant,
+ inst->getDataType(),
+ kResultID,
+ (SpvWord)(FloatAsInt((float)value)));
+ case BaseType::Double:
+ {
+ auto ival = DoubleAsInt64(value);
+ return emitInst(
+ getSection(SpvLogicalSectionID::Constants),
+ inst,
+ SpvOpConstant,
+ inst->getDataType(),
+ kResultID,
+ (SpvWord)(ival&0xFFFFFFFF),
+ (SpvWord)(ival>>32));
+ }
+ default:
+ return nullptr;
+ }
+ }
+ case kIROp_BoolLit:
+ {
+ if (as<IRBoolLit>(inst)->getValue())
+ {
+ return emitInst(
+ getSection(SpvLogicalSectionID::Constants),
+ inst,
+ SpvOpConstantTrue,
+ inst->getDataType(),
+ kResultID);
+ }
+ else
+ {
+ return emitInst(
+ getSection(SpvLogicalSectionID::Constants),
+ inst,
+ SpvOpConstantFalse,
+ inst->getDataType(),
+ kResultID);
+ }
+ }
+ default:
+ return nullptr;
}
}
@@ -1184,24 +1742,655 @@ struct SPIRVEmitContext
}
}
- SPIRVEmitContext(IRModule* module) :
- m_irModule(module),
- m_memoryArena(2048)
+ SpvInst* emitBuiltinSystemVal(SpvInstParent* parent, IRInst* inst, SpvBuiltIn builtinVal)
+ {
+ IRBuilder builder;
+ builder.sharedBuilder = &m_sharedIRBuilder;
+ builder.setInsertBefore(inst);
+
+ auto ptrIRType = builder.getPtrType(inst->getDataType());
+ auto varInst = emitInst(parent, inst, SpvOpVariable, ptrIRType, kResultID);
+ emitInst(
+ getSection(SpvLogicalSectionID::Annotations),
+ nullptr,
+ SpvOpDecorate,
+ varInst,
+ SpvDecorationBuiltIn,
+ builtinVal);
+ return varInst;
+ }
+
+ SpvInst* emitParam(SpvInstParent* parent, IRInst* inst)
+ {
+ if (auto layout = getVarLayout(inst))
+ {
+ if (auto systemValueAttr = layout->findAttr<IRSystemValueSemanticAttr>())
+ {
+ String semanticName = systemValueAttr->getName();
+ semanticName = semanticName.toLower();
+ if (semanticName == "sv_dispatchthreadid")
+ {
+ return emitBuiltinSystemVal(parent, inst, SpvBuiltInGlobalInvocationId);
+ }
+ }
+ }
+ return emitInst(parent, inst, SpvOpFunctionParameter, inst->getFullType(), kResultID);
+ }
+
+ SpvInst* emitVar(SpvInstParent* parent, IRInst* inst)
+ {
+ SpvWord storageClass = SpvStorageClassFunction;
+ auto rate = inst->getFullType()->getRate();
+ if (rate)
+ {
+ switch (rate->getOp())
+ {
+ case kIROp_GroupSharedRate:
+ storageClass = SpvStorageClassWorkgroup;
+ break;
+ default:
+ break;
+ }
+ }
+ return emitInst(parent, inst, SpvOpVariable, inst->getFullType(), kResultID, storageClass);
+ }
+
+ /// Cached `IRParam` indices in an `IRBlock`. For use in `getParamIndexInBlock`.
+ struct BlockParamIndexInfo : public RefObject
+ {
+ Dictionary<IRParam*, int> mapParamToIndex;
+ };
+ Dictionary<IRBlock*, RefPtr<BlockParamIndexInfo>> m_mapIRBlockToParamIndexInfo;
+
+ /// Returns the index of an `IRParam` inside a `IRBlock`.
+ /// The results are cached in `m_mapIRBlockToParamIndexInfo` to avoid linear search.
+ int getParamIndexInBlock(IRBlock* block, IRParam* paramInst)
+ {
+ RefPtr<BlockParamIndexInfo> info;
+ int result = -1;
+ if (m_mapIRBlockToParamIndexInfo.TryGetValue(block, info))
+ {
+ info->mapParamToIndex.TryGetValue(paramInst, result);
+ SLANG_ASSERT(result != -1);
+ return result;
+ }
+ info = new BlockParamIndexInfo();
+ int paramIndex = 0;
+ for (auto param : block->getParams())
+ {
+ info->mapParamToIndex[param] = paramIndex;
+ if (param == paramInst)
+ result = paramIndex;
+ paramIndex++;
+ }
+ m_mapIRBlockToParamIndexInfo[block] = info;
+ SLANG_ASSERT(result != -1);
+ return result;
+ }
+
+ SpvInst* emitPhi(SpvInstParent* parent, IRParam* inst)
+ {
+ // An `IRParam` in an ordinary `IRBlock` represents a phi value.
+ // We can translate them directly to SPIRV's `Phi` instruction.
+ // In order to do that, we need to figure out the source values
+ // of this `IRParam`, which can be done by looking at the users
+ // of current `IRBlock`.
+
+ // First, we find the index of this param.
+ IRBlock* block = as<IRBlock>(inst->getParent());
+ SLANG_ASSERT(block);
+ int paramIndex = getParamIndexInBlock(block, inst);
+
+ // Emit a Phi instruction.
+ return emitInstCustomOperandFunc(parent, inst, SpvOpPhi, [&]() {
+ emitOperand(inst->getFullType());
+ emitOperand(kResultID);
+ // Find phi arguments from incoming branch instructions that target `block`.
+ for (auto use = block->firstUse; use; use = use->nextUse)
+ {
+ auto branchInst = use->getUser();
+ UInt argStartIndex = 0;
+ switch (branchInst->getOp())
+ {
+ case kIROp_unconditionalBranch:
+ argStartIndex = 1;
+ break;
+ case kIROp_loop:
+ argStartIndex = 3;
+ break;
+ default:
+ // A phi argument can only come from an unconditional branch inst.
+ // Other uses are not relavent so we should skip.
+ continue;
+ }
+ SLANG_ASSERT(argStartIndex + paramIndex < branchInst->getOperandCount());
+ auto valueInst = branchInst->getOperand(argStartIndex + paramIndex);
+ emitOperand(valueInst);
+ auto sourceBlock = as<IRBlock>(branchInst->getParent());
+ SLANG_ASSERT(sourceBlock);
+ emitOperand(getIRInstSpvID(sourceBlock));
+ }
+ });
+ }
+
+ SpvInst* emitCall(SpvInstParent* parent, IRInst* inst)
+ {
+ auto funcValue = inst->getOperand(0);
+
+ // Does this function declare any requirements.
+ handleRequiredCapabilities(funcValue);
+
+ // We want to detect any call to an intrinsic operation, and inline
+ // the SPIRV snippet directly at the call site.
+ if (auto targetIntrinsic = Slang::findBestTargetIntrinsicDecoration(
+ funcValue, m_targetRequest->getTargetCaps()))
+ {
+ return emitIntrinsicCallExpr(parent, static_cast<IRCall*>(inst), targetIntrinsic);
+ }
+ else
+ {
+ return emitInst(
+ parent, inst, SpvOpFunctionCall, inst->getFullType(), kResultID, OperandsOf(inst));
+ }
+ }
+
+ SpvInst* emitIntrinsicCallExpr(
+ SpvInstParent* parent,
+ IRCall* inst,
+ IRTargetIntrinsicDecoration* intrinsic)
+ {
+ SpvSnippet* snippet = getParsedSpvSnippet(intrinsic);
+ SpvSnippetEmitContext context;
+ context.resultType = ensureInst(inst->getFullType());
+ for (SlangUInt i = 0; i < inst->getArgCount(); i++)
+ {
+ auto argInst = ensureInst(inst->getArg(i));
+ if (argInst)
+ {
+ context.argumentIds.add(getID(argInst));
+ }
+ else
+ {
+ context.argumentIds.add(0xFFFFFFFF);
+ }
+ }
+ // A SPIRV snippet may refer to the result type of this inst with a
+ // different storage-class qualifier. We need to pre-create these
+ // storage-class-qualified result pointer types so they can be used
+ // during inlining of the snippet.
+ if (auto oldPtrType = as<IRPtrTypeBase>(inst->getDataType()))
+ {
+ for (auto storageClass : snippet->usedResultTypeStorageClasses)
+ {
+ IRBuilder builder;
+ builder.sharedBuilder = &m_sharedIRBuilder;
+ builder.setInsertBefore(inst);
+ auto newPtrType = builder.getPtrType(
+ oldPtrType->getOp(), oldPtrType->getValueType(), storageClass);
+ context.qualifiedResultTypes[storageClass] = newPtrType;
+ }
+ }
+ return emitSpvSnippet(parent, inst, context, snippet);
+ }
+
+ SpvInst* emitSpvSnippet(
+ SpvInstParent* parent,
+ IRCall* inst,
+ const SpvSnippetEmitContext& context,
+ SpvSnippet* snippet)
+ {
+ ShortList<SpvInst*> emittedInsts;
+ for (Index i = 0; i < snippet->instructions.getCount(); i++)
+ {
+ auto& spvSnippetInst = snippet->instructions[i];
+ InstConstructScope scopeInst(this, (SpvOp)spvSnippetInst.opCode, nullptr);
+ SpvInst* spvInst = scopeInst;
+ for (auto operand : spvSnippetInst.operands)
+ {
+ switch (operand.type)
+ {
+ case SpvSnippet::ASMOperandType::SpvWord:
+ emitOperand((SpvWord)operand.content);
+ break;
+ case SpvSnippet::ASMOperandType::ObjectReference:
+ SLANG_ASSERT(
+ operand.content >= 0 && operand.content < context.argumentIds.getCount());
+ emitOperand(context.argumentIds[operand.content]);
+ break;
+ case SpvSnippet::ASMOperandType::ResultId:
+ emitOperand(kResultID);
+ break;
+ case SpvSnippet::ASMOperandType::ResultTypeId:
+ if (operand.content != -1)
+ {
+ emitOperand(context.qualifiedResultTypes[(SpvStorageClass)operand.content]
+ .GetValue());
+ }
+ else
+ {
+ emitOperand(context.resultType);
+ }
+ break;
+ case SpvSnippet::ASMOperandType::InstReference:
+ SLANG_ASSERT(operand.content >= 0 && operand.content < emittedInsts.getCount());
+ emitOperand(getID(emittedInsts[operand.content]));
+ break;
+ }
+ }
+ parent->addInst(spvInst);
+ emittedInsts.add(spvInst);
+ }
+ auto resultInst = emittedInsts.getLast();
+ registerInst(inst, resultInst);
+ return resultInst;
+ }
+
+ struct StructTypeInfo : public RefObject
+ {
+ Dictionary<IRStructKey*, Index> structFieldIndices;
+ };
+
+ Dictionary<IRStructType*, RefPtr<StructTypeInfo>> m_structTypeInfos;
+
+ RefPtr<StructTypeInfo> createStructTypeInfo(IRStructType* structType)
+ {
+ RefPtr<StructTypeInfo> typeInfo = new StructTypeInfo();
+ Index index = 0;
+ for (auto field : structType->getFields())
+ {
+ typeInfo->structFieldIndices[field->getKey()] = index;
+ index++;
+ }
+ return typeInfo;
+ }
+ Index getStructFieldId(IRStructType* structType, IRStructKey* structFieldKey)
+ {
+ RefPtr<StructTypeInfo> info;
+ if (!m_structTypeInfos.TryGetValue(structType, info))
+ {
+ info = createStructTypeInfo(structType);
+ m_structTypeInfos[structType] = info;
+ }
+ Index fieldIndex = -1;
+ info->structFieldIndices.TryGetValue(structFieldKey, fieldIndex);
+ SLANG_ASSERT(fieldIndex != -1);
+ return fieldIndex;
+ }
+
+ SpvInst* emitFieldAddress(SpvInstParent* parent, IRFieldAddress* fieldAddress)
+ {
+ IRBuilder builder;
+ builder.sharedBuilder = &m_sharedIRBuilder;
+ builder.setInsertBefore(fieldAddress);
+
+ auto base = fieldAddress->getBase();
+ SpvWord baseId = 0;
+ IRStructType* baseStructType = nullptr;
+
+ if (auto ptrLikeType = as<IRPointerLikeType>(base->getDataType()))
+ {
+ baseStructType = as<IRStructType>(ptrLikeType->getElementType());
+ baseId = getID(ensureInst(base));
+ }
+ else if (auto ptrType = as<IRPtrTypeBase>(base->getDataType()))
+ {
+ baseStructType = as<IRStructType>(ptrType->getValueType());
+ baseId = getID(ensureInst(base));
+ }
+ else
+ {
+ baseStructType = as<IRStructType>(base->getDataType());
+
+ auto structPtrType = builder.getPtrType(baseStructType);
+ auto varInst = emitInst(
+ parent, nullptr, SpvOpVariable, structPtrType, kResultID, SpvStorageClassFunction);
+ emitInst(parent, nullptr, SpvOpStore, varInst, base);
+ baseId = getID(varInst);
+ }
+ SLANG_ASSERT(baseStructType && "field_address require base to be a struct.");
+ auto fieldId = emitConstant(
+ getStructFieldId(baseStructType, as<IRStructKey>(fieldAddress->getField())),
+ builder.getIntType());
+ return emitInst(
+ parent,
+ fieldAddress,
+ SpvOpAccessChain,
+ fieldAddress->getFullType(),
+ kResultID,
+ baseId,
+ fieldId);
+ }
+
+ SpvInst* emitFieldExtract(SpvInstParent* parent, IRFieldExtract* inst)
+ {
+ IRBuilder builder;
+ builder.sharedBuilder = &m_sharedIRBuilder;
+ builder.setInsertBefore(inst);
+
+ IRStructType* baseStructType = as<IRStructType>(inst->getBase()->getDataType());
+ SLANG_ASSERT(baseStructType && "field_extract require base to be a struct.");
+ auto fieldId = emitConstant(
+ getStructFieldId(baseStructType, as<IRStructKey>(inst->getField())),
+ builder.getIntType());
+
+ return emitInst(
+ parent,
+ inst,
+ SpvOpCompositeExtract,
+ inst->getDataType(),
+ kResultID,
+ inst->getBase(),
+ fieldId);
+ }
+
+ SpvInst* emitGetElementPtr(SpvInstParent* parent, IRGetElementPtr* inst)
+ {
+ auto base = inst->getBase();
+ SpvWord baseId = 0;
+ IRArrayType* baseArrayType = nullptr;
+
+ if (auto ptrLikeType = as<IRPointerLikeType>(base->getDataType()))
+ {
+ baseArrayType = as<IRArrayType>(ptrLikeType->getElementType());
+ baseId = getID(ensureInst(base));
+ }
+ else if (auto ptrType = as<IRPtrTypeBase>(base->getDataType()))
+ {
+ baseArrayType = as<IRArrayType>(ptrType->getValueType());
+ baseId = getID(ensureInst(base));
+ }
+ else
+ {
+ SLANG_ASSERT(!"invalid IR: base of getElementPtr must be a pointer.");
+ }
+ SLANG_ASSERT(baseArrayType && "getElementPtr require base to be an array.");
+ return emitInst(
+ parent,
+ inst,
+ SpvOpAccessChain,
+ inst->getFullType(),
+ kResultID,
+ baseId,
+ inst->getIndex());
+ }
+
+ SpvInst* emitGetElement(SpvInstParent* parent, IRGetElement* inst)
+ {
+ auto base = inst->getBase();
+ SpvWord baseId = 0;
+ IRArrayType* baseArrayType = nullptr;
+
+ if (auto ptrLikeType = as<IRPointerLikeType>(base->getDataType()))
+ {
+ baseArrayType = as<IRArrayType>(ptrLikeType->getElementType());
+ baseId = getID(ensureInst(base));
+ }
+ else if (auto ptrType = as<IRPtrTypeBase>(base->getDataType()))
+ {
+ baseArrayType = as<IRArrayType>(ptrType->getValueType());
+ baseId = getID(ensureInst(base));
+ }
+ else
+ {
+ SLANG_ASSERT(!"invalid IR: base of getElement must be a pointer.");
+ }
+ SLANG_ASSERT(baseArrayType && "getElement require base to be an array.");
+
+ IRBuilder builder;
+ builder.sharedBuilder = &m_sharedIRBuilder;
+ builder.setInsertBefore(inst);
+
+ auto ptr = emitInst(
+ parent,
+ nullptr,
+ SpvOpAccessChain,
+ builder.getPtrType(inst->getFullType()),
+ kResultID,
+ baseId,
+ inst->getIndex());
+ return emitInst(parent, inst, SpvOpLoad, inst->getFullType(), kResultID, ptr);
+ }
+
+ SpvInst* emitLoad(SpvInstParent* parent, IRLoad* inst)
+ {
+ return emitInst(parent, inst, SpvOpLoad, inst->getDataType(), kResultID, inst->getPtr());
+ }
+
+ SpvInst* emitStore(SpvInstParent* parent, IRStore* inst)
+ {
+ return emitInst(parent, inst, SpvOpStore, inst->getPtr(), inst->getVal());
+ }
+
+ SpvInst* emitSwizzle(SpvInstParent* parent, IRSwizzle* inst)
+ {
+ return emitInstCustomOperandFunc(parent, inst, SpvOpVectorShuffle, [&]() {
+ emitOperand(inst->getDataType());
+ emitOperand(kResultID);
+ emitOperand(inst->getBase());
+ emitOperand(inst->getBase());
+ for (UInt i = 0; i < inst->getElementCount(); i++)
+ {
+ auto index = as<IRIntLit>(inst->getElementIndex(i));
+ emitOperand((SpvWord)index->getValue());
+ }
+ });
+ }
+
+ SpvInst* emitConstruct(SpvInstParent* parent, IRInst* inst)
+ {
+ if (as<IRBasicType>(inst->getDataType()))
+ {
+ if (inst->getOperandCount() == 1)
+ {
+ if (inst->getDataType() == inst->getOperand(0)->getDataType())
+ return emitInst(parent, inst, SpvOpCopyObject, kResultID, inst->getOperand(0));
+ else
+ return emitInst(parent, inst, SpvOpBitcast, inst->getDataType(), kResultID, inst->getOperand(0));
+ }
+ else
+ {
+ SLANG_ASSERT(!"spirv emit: unsupported Construct inst.");
+ return nullptr;
+ }
+ }
+ else
+ {
+ return emitInst(
+ parent,
+ inst,
+ SpvOpCompositeConstruct,
+ inst->getDataType(),
+ kResultID,
+ OperandsOf(inst));
+ }
+ }
+
+ bool isSignedType(IRBasicType* basicType)
+ {
+ switch (basicType->getBaseType())
+ {
+ case BaseType::Float:
+ case BaseType::Double:
+ return true;
+ case BaseType::Int:
+ case BaseType::Int16:
+ case BaseType::Int64:
+ case BaseType::Int8:
+ return true;
+ default:
+ return false;
+ }
+ }
+
+ SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst)
+ {
+ IRType* elementType = inst->getDataType();
+ if (auto vectorType = as<IRVectorType>(inst->getDataType()))
+ {
+ elementType = vectorType->getElementType();
+ }
+ else if (auto matrixType = as<IRMatrixType>(inst->getDataType()))
+ {
+ //TODO: implement.
+ SLANG_ASSERT(!"unimplemented: matrix arithemetic");
+ }
+ IRBasicType* basicType = as<IRBasicType>(elementType);
+ bool isFloatingPoint = false;
+ bool isBool = false;
+ switch (basicType->getBaseType())
+ {
+ case BaseType::Float:
+ case BaseType::Double:
+ isFloatingPoint = true;
+ break;
+ case BaseType::Bool:
+ isBool = true;
+ default:
+ break;
+ }
+ SpvOp opCode = SpvOpUndef;
+ bool isSigned = isSignedType(basicType);
+ switch (inst->getOp())
+ {
+ case kIROp_Add:
+ opCode = isFloatingPoint ? SpvOpFAdd : SpvOpIAdd;
+ break;
+ case kIROp_Sub:
+ opCode = isFloatingPoint ? SpvOpFSub : SpvOpISub;
+ break;
+ case kIROp_Mul:
+ opCode = isFloatingPoint ? SpvOpFMul : SpvOpIMul;
+ break;
+ case kIROp_Div:
+ opCode = isFloatingPoint ? SpvOpFDiv : isSigned ? SpvOpSDiv : SpvOpUDiv;
+ break;
+ case kIROp_IRem:
+ opCode = isSigned ? SpvOpSRem : SpvOpUMod;
+ break;
+ case kIROp_FRem:
+ opCode = SpvOpFRem;
+ break;
+ case kIROp_Less:
+ opCode = isFloatingPoint ? SpvOpFOrdLessThan
+ : isSigned ? SpvOpSLessThan : SpvOpULessThan;
+ break;
+ case kIROp_Leq:
+ opCode = isFloatingPoint ? SpvOpFOrdLessThanEqual
+ : isSigned ? SpvOpSLessThanEqual : SpvOpULessThanEqual;
+ break;
+ case kIROp_Eql:
+ opCode = isFloatingPoint ? SpvOpFOrdEqual : isBool ? SpvOpLogicalEqual : SpvOpIEqual;
+ break;
+ case kIROp_Neq:
+ opCode = isFloatingPoint ? SpvOpFOrdNotEqual
+ : isBool ? SpvOpLogicalNotEqual : SpvOpINotEqual;
+ break;
+ case kIROp_Geq:
+ opCode = isFloatingPoint ? SpvOpFOrdGreaterThanEqual
+ : isSigned ? SpvOpSGreaterThanEqual : SpvOpUGreaterThanEqual;
+ break;
+ case kIROp_Greater:
+ opCode = isFloatingPoint ? SpvOpFOrdGreaterThan
+ : isSigned ? SpvOpSGreaterThan : SpvOpUGreaterThan;
+ break;
+ case kIROp_Neg:
+ opCode = isFloatingPoint ? SpvOpFNegate : SpvOpSNegate;
+ break;
+ case kIROp_And:
+ opCode = SpvOpLogicalAnd;
+ break;
+ case kIROp_Or:
+ opCode = SpvOpLogicalOr;
+ break;
+ case kIROp_Not:
+ opCode = SpvOpLogicalNot;
+ break;
+ case kIROp_BitAnd:
+ opCode = SpvOpBitwiseAnd;
+ break;
+ case kIROp_BitOr:
+ opCode = SpvOpBitwiseOr;
+ break;
+ case kIROp_BitXor:
+ opCode = SpvOpBitwiseXor;
+ break;
+ case kIROp_BitNot:
+ opCode = SpvOpBitReverse;
+ break;
+ case kIROp_Rsh:
+ opCode = isSigned ? SpvOpShiftRightArithmetic : SpvOpShiftRightLogical;
+ break;
+ case kIROp_Lsh:
+ opCode = SpvOpShiftLeftLogical;
+ break;
+ default:
+ SLANG_ASSERT(!"unknown arithmetic opcode");
+ break;
+ }
+ return emitInst(parent, inst, opCode, inst->getDataType(), kResultID, OperandsOf(inst));
+ }
+
+ OrderedHashSet<SpvCapability> m_capabilities;
+
+ void requireSPIRVCapability(SpvCapability capability)
+ {
+ if (m_capabilities.Add(capability))
+ {
+ emitInst(
+ getSection(SpvLogicalSectionID::Capabilities),
+ nullptr,
+ SpvOpCapability,
+ capability);
+ }
+ }
+
+ void handleRequiredCapabilitiesImpl(IRInst* inst)
+ {
+ // TODO: declare required SPV capabilities.
+
+ for (auto decoration : inst->getDecorations())
+ {
+ switch (decoration->getOp())
+ {
+ default:
+ break;
+
+ case kIROp_RequireGLSLExtensionDecoration:
+ {
+ break;
+ }
+ case kIROp_RequireGLSLVersionDecoration:
+ {
+ break;
+ }
+ case kIROp_RequireSPIRVVersionDecoration:
+ {
+ break;
+ }
+ }
+ }
+ }
+
+ SPIRVEmitContext(IRModule* module, TargetRequest* target, DiagnosticSink* sink)
+ : SPIRVEmitSharedContext(module, target)
+ , m_irModule(module)
+ , m_sink(sink)
+ , m_memoryArena(2048)
{
}
};
SlangResult emitSPIRVFromIR(
BackEndCompileRequest* compileRequest,
+ TargetRequest* targetRequest,
IRModule* irModule,
const List<IRFunc*>& irEntryPoints,
List<uint8_t>& spirvOut)
{
- SLANG_UNUSED(compileRequest);
-
spirvOut.clear();
- SPIRVEmitContext context(irModule);
+ SPIRVEmitContext context(irModule, targetRequest, compileRequest->getSink());
+ legalizeIRForSPIRV(&context, irModule, compileRequest->getSink());
context.emitFrontMatter();
for (auto irEntryPoint : irEntryPoints)
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 352a27746..3da19cef1 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -916,6 +916,7 @@ SlangResult emitEntryPointsSourceFromIR(
SlangResult emitSPIRVFromIR(
BackEndCompileRequest* compileRequest,
+ TargetRequest* targetRequest,
IRModule* irModule,
const List<IRFunc*>& irEntryPoints,
List<uint8_t>& spirvOut);
@@ -947,11 +948,7 @@ SlangResult emitSPIRVForEntryPointsDirectly(
auto irModule = linkedIR.module;
auto irEntryPoints = linkedIR.entryPoints;
- emitSPIRVFromIR(
- compileRequest,
- irModule,
- irEntryPoints,
- spirvOut);
+ emitSPIRVFromIR(compileRequest, targetRequest, irModule, irEntryPoints, spirvOut);
return SLANG_OK;
}
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 267866b1b..25313d2f5 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -1904,6 +1904,7 @@ struct IRBuilder
IRInOutType* getInOutType(IRType* valueType);
IRRefType* getRefType(IRType* valueType);
IRPtrTypeBase* getPtrType(IROp op, IRType* valueType);
+ IRPtrType* getPtrType(IROp op, IRType* valueType, IRIntegerValue addressSpace);
IRArrayTypeBase* getArrayTypeBase(
IROp op,
@@ -2734,6 +2735,14 @@ IRTargetSpecificDecoration* findBestTargetDecoration(
IRInst* val,
CapabilityAtom targetCapabilityAtom);
+inline IRTargetIntrinsicDecoration* findBestTargetIntrinsicDecoration(
+ IRInst* inInst,
+ CapabilitySet const& targetCaps)
+{
+ return as<IRTargetIntrinsicDecoration>(findBestTargetDecoration(inInst, targetCaps));
+}
+
+
}
#endif
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
new file mode 100644
index 000000000..f7fc53bdb
--- /dev/null
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -0,0 +1,258 @@
+// slang-ir-spirv-legalize.cpp
+#include "slang-ir-spirv-legalize.h"
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-emit-base.h"
+#include "slang-glsl-extension-tracker.h"
+
+namespace Slang
+{
+
+//
+// Legalization of IR for direct SPIRV emit.
+//
+
+struct StorageClassPropagationContext : public SourceEmitterBase
+{
+ SPIRVEmitSharedContext* m_sharedContext;
+
+ IRModule* m_module;
+ // We will use a single work list of instructions that need
+ // to be considered for specialization or simplification,
+ // whether generic, existential, etc.
+ //
+ OrderedHashSet<IRInst*> workList;
+
+ void addToWorkList(IRInst* inst)
+ {
+ if (workList.Add(inst))
+ {
+ addUsersToWorkList(inst);
+ }
+ }
+
+ void addUsersToWorkList(IRInst* inst)
+ {
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ auto user = use->getUser();
+
+ addToWorkList(user);
+ }
+ }
+
+ StorageClassPropagationContext(SPIRVEmitSharedContext* sharedContext, IRModule* module)
+ : m_sharedContext(sharedContext), m_module(module)
+ {
+ }
+
+ void processGlobalParam(IRGlobalParam* inst) { processGlobalVar(inst); }
+
+ void processGlobalVar(IRInst* inst)
+ {
+ auto oldPtrType = as<IRPtrTypeBase>(inst->getDataType());
+ if (!oldPtrType)
+ return;
+
+ // If the pointer type is already qualified with address spaces (such as
+ // lowered pointer type from a `HLSLStructuredBufferType`), make no
+ // further modifications.
+ if (oldPtrType->hasAddressSpace())
+ {
+ addUsersToWorkList(inst);
+ return;
+ }
+
+ auto varLayout = getVarLayout(inst);
+ if (!varLayout)
+ return;
+
+ SpvStorageClass storageClass = SpvStorageClassPrivate;
+ for (auto rr : varLayout->getOffsetAttrs())
+ {
+ switch (rr->getResourceKind())
+ {
+ case LayoutResourceKind::Uniform:
+ case LayoutResourceKind::ShaderResource:
+ case LayoutResourceKind::DescriptorTableSlot:
+ storageClass = SpvStorageClassUniform;
+ break;
+ case LayoutResourceKind::VaryingInput:
+ storageClass = SpvStorageClassInput;
+ break;
+ case LayoutResourceKind::VaryingOutput:
+ storageClass = SpvStorageClassOutput;
+ break;
+ case LayoutResourceKind::UnorderedAccess:
+ storageClass = SpvStorageClassStorageBuffer;
+ break;
+ case LayoutResourceKind::PushConstantBuffer:
+ storageClass = SpvStorageClassPushConstant;
+ break;
+ default:
+ break;
+ }
+ }
+ auto rate = inst->getRate();
+ if (as<IRGroupSharedRate>(rate))
+ {
+ storageClass = SpvStorageClassWorkgroup;
+ }
+ IRBuilder builder;
+ builder.sharedBuilder = &m_sharedContext->m_sharedIRBuilder;
+ builder.setInsertBefore(inst);
+ auto newPtrType =
+ builder.getPtrType(oldPtrType->getOp(), oldPtrType->getValueType(), storageClass);
+ inst->setFullType(newPtrType);
+ addUsersToWorkList(inst);
+ return;
+ }
+
+ void processCall(IRCall* inst)
+ {
+ auto funcValue = inst->getOperand(0);
+ if (auto targetIntrinsic = Slang::findBestTargetIntrinsicDecoration(
+ funcValue, m_sharedContext->m_targetRequest->getTargetCaps()))
+ {
+ SpvSnippet* snippet = m_sharedContext->getParsedSpvSnippet(targetIntrinsic);
+ if (!snippet)
+ return;
+ if (snippet->resultStorageClass != SpvStorageClassMax)
+ {
+ auto ptrType = as<IRPtrTypeBase>(inst->getDataType());
+ if (!ptrType)
+ return;
+ IRBuilder builder;
+ builder.sharedBuilder = &m_sharedContext->m_sharedIRBuilder;
+ builder.setInsertBefore(inst);
+ auto qualPtrType = builder.getPtrType(
+ ptrType->getOp(), ptrType->getValueType(), snippet->resultStorageClass);
+ List<IRInst*> args;
+ for (UInt i = 0; i < inst->getArgCount(); i++)
+ args.add(inst->getArg(i));
+ auto newCall = builder.emitCallInst(qualPtrType, funcValue, args);
+ inst->replaceUsesWith(newCall);
+ inst->removeAndDeallocate();
+ addUsersToWorkList(newCall);
+ }
+ }
+ }
+
+ void processGetElementPtr(IRGetElementPtr* inst)
+ {
+ if (auto ptrType = as<IRPtrTypeBase>(inst->getBase()->getDataType()))
+ {
+ if (!ptrType->hasAddressSpace())
+ return;
+ auto oldResultType = as<IRPtrTypeBase>(inst->getDataType());
+ if (oldResultType->getAddressSpace() != ptrType->getAddressSpace())
+ {
+ IRBuilder builder;
+ builder.sharedBuilder = &m_sharedContext->m_sharedIRBuilder;
+ builder.setInsertBefore(inst);
+ auto newPtrType = builder.getPtrType(
+ oldResultType->getOp(),
+ oldResultType->getValueType(),
+ ptrType->getAddressSpace());
+ auto newInst =
+ builder.emitElementAddress(newPtrType, inst->getBase(), inst->getIndex());
+ inst->replaceUsesWith(newInst);
+ inst->removeAndDeallocate();
+ addUsersToWorkList(newInst);
+ }
+ }
+ }
+
+ void processFieldAddress(IRFieldAddress* inst)
+ {
+ if (auto ptrType = as<IRPtrTypeBase>(inst->getBase()->getDataType()))
+ {
+ if (!ptrType->hasAddressSpace())
+ return;
+ auto oldResultType = as<IRPtrTypeBase>(inst->getDataType());
+ if (oldResultType->getAddressSpace() != ptrType->getAddressSpace())
+ {
+ IRBuilder builder;
+ builder.sharedBuilder = &m_sharedContext->m_sharedIRBuilder;
+ builder.setInsertBefore(inst);
+ auto newPtrType = builder.getPtrType(
+ oldResultType->getOp(),
+ oldResultType->getValueType(),
+ ptrType->getAddressSpace());
+ auto newInst =
+ builder.emitFieldAddress(newPtrType, inst->getBase(), inst->getField());
+ inst->replaceUsesWith(newInst);
+ inst->removeAndDeallocate();
+ addUsersToWorkList(newInst);
+ }
+ }
+ }
+
+ void processStructuredBufferType(IRHLSLStructuredBufferTypeBase* inst)
+ {
+ IRBuilder builder;
+ builder.sharedBuilder = &m_sharedContext->m_sharedIRBuilder;
+ builder.setInsertBefore(inst);
+ auto arrayType = builder.getUnsizedArrayType(inst->getElementType());
+ auto ptrType = builder.getPtrType(kIROp_PtrType, arrayType, SpvStorageClassStorageBuffer);
+ inst->replaceUsesWith(ptrType);
+ inst->removeAndDeallocate();
+ addUsersToWorkList(ptrType);
+ }
+
+ void propagate()
+ {
+ addToWorkList(m_module->getModuleInst());
+ while (workList.Count() != 0)
+ {
+ IRInst* inst = workList.getLast();
+ workList.removeLast();
+ switch (inst->getOp())
+ {
+ case kIROp_GlobalParam:
+ processGlobalParam(as<IRGlobalParam>(inst));
+ break;
+ case kIROp_GlobalVar:
+ processGlobalVar(as<IRGlobalVar>(inst));
+ break;
+ case kIROp_Call:
+ processCall(as<IRCall>(inst));
+ break;
+ case kIROp_getElementPtr:
+ processGetElementPtr(as<IRGetElementPtr>(inst));
+ break;
+ case kIROp_FieldAddress:
+ processFieldAddress(as<IRFieldAddress>(inst));
+ break;
+ case kIROp_HLSLStructuredBufferType:
+ case kIROp_HLSLRWStructuredBufferType:
+ processStructuredBufferType(as<IRHLSLStructuredBufferTypeBase>(inst));
+ break;
+ default:
+ for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
+ {
+ addToWorkList(child);
+ }
+ break;
+ }
+ }
+ }
+};
+
+void propagateStorageClass(SPIRVEmitSharedContext* sharedContext, IRModule* module)
+{
+ StorageClassPropagationContext context(sharedContext, module);
+ context.propagate();
+}
+
+void legalizeIRForSPIRV(
+ SPIRVEmitSharedContext* context,
+ IRModule* module,
+ DiagnosticSink* sink)
+{
+ SLANG_UNUSED(sink);
+ propagateStorageClass(context, module);
+}
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-spirv-legalize.h b/source/slang/slang-ir-spirv-legalize.h
new file mode 100644
index 000000000..bf43430d8
--- /dev/null
+++ b/source/slang/slang-ir-spirv-legalize.h
@@ -0,0 +1,45 @@
+// slang-ir-spirv-legalize.h
+#pragma once
+#include "../core/slang-basic.h"
+#include "slang-ir-spirv-snippet.h"
+#include "slang-ir-insts.h"
+
+namespace Slang
+{
+
+class DiagnosticSink;
+
+struct IRFunc;
+struct IRModule;
+class TargetRequest;
+
+struct SPIRVEmitSharedContext
+{
+ SharedIRBuilder m_sharedIRBuilder;
+ Dictionary<IRTargetIntrinsicDecoration*, RefPtr<SpvSnippet>> m_parsedSpvSnippets;
+ TargetRequest* m_targetRequest;
+
+ SPIRVEmitSharedContext(IRModule* module, TargetRequest* target)
+ : m_sharedIRBuilder(module)
+ , m_targetRequest(target)
+ {}
+
+ SpvSnippet* getParsedSpvSnippet(IRTargetIntrinsicDecoration* intrinsic)
+ {
+ RefPtr<SpvSnippet> snippet;
+ if (m_parsedSpvSnippets.TryGetValue(intrinsic, snippet))
+ {
+ return snippet.Ptr();
+ }
+ snippet = SpvSnippet::parse(intrinsic->getDefinition());
+ m_parsedSpvSnippets[intrinsic] = snippet;
+ return snippet;
+ }
+};
+
+void legalizeIRForSPIRV(
+ SPIRVEmitSharedContext* context,
+ IRModule* module,
+ DiagnosticSink* sink);
+
+}
diff --git a/source/slang/slang-ir-spirv-snippet.cpp b/source/slang/slang-ir-spirv-snippet.cpp
new file mode 100644
index 000000000..4083f100d
--- /dev/null
+++ b/source/slang/slang-ir-spirv-snippet.cpp
@@ -0,0 +1,124 @@
+// slang-ir-spirv-snippet.cpp
+
+#include"slang-ir-spirv-snippet.h"
+#include "../core/slang-token-reader.h"
+
+namespace Slang
+{
+static SpvStorageClass translateStorageClass(String name)
+{
+ if (name == "Uniform")
+ {
+ return SpvStorageClassUniform;
+ }
+ else if (name == "StorageBuffer")
+ {
+ return SpvStorageClassStorageBuffer;
+ }
+ return (SpvStorageClass)-1;
+}
+
+RefPtr<SpvSnippet> SpvSnippet::parse(UnownedStringSlice definition)
+{
+ RefPtr<SpvSnippet> snippet = new SpvSnippet();
+ try
+ {
+ Dictionary<String, int> mapInstNameToIndex;
+ Slang::Misc::TokenReader tokenReader(definition);
+ // A leading "*" at the beginning of the snip modifies $resultType with
+ // a storage class.
+ if (tokenReader.AdvanceIf("*"))
+ {
+ auto storageToken = tokenReader.ReadWord();
+ snippet->resultStorageClass = translateStorageClass(storageToken);
+
+ }
+ while (!tokenReader.IsEnd())
+ {
+ SpvSnippet::ASMInst inst;
+ if (tokenReader.AdvanceIf("%"))
+ {
+ String instName = tokenReader.ReadToken().Content;
+ mapInstNameToIndex[instName] = (int)snippet->instructions.getCount();
+ tokenReader.Read(Slang::Misc::TokenType::OpAssign);
+ }
+ inst.opCode = (SpvWord)tokenReader.ReadInt();
+ bool insideOperandList = true;
+ while (insideOperandList)
+ {
+ ASMOperand operand = {ASMOperandType::SpvWord, 0};
+ switch (tokenReader.NextToken().Type)
+ {
+ case Slang::Misc::TokenType::Semicolon:
+ insideOperandList = false;
+ tokenReader.ReadToken();
+ break;
+ case Slang::Misc::TokenType::IntLiteral:
+ operand.type = SpvSnippet::ASMOperandType::SpvWord;
+ operand.content = tokenReader.ReadInt();
+ inst.operands.add(operand);
+ break;
+ case Slang::Misc::TokenType::OpMod:
+ {
+ operand.type = SpvSnippet::ASMOperandType::InstReference;
+ auto refName = tokenReader.ReadToken().Content;
+ if (!mapInstNameToIndex.TryGetValue(refName, operand.content))
+ {
+ SLANG_ASSERT(!"Invalid SPV ASM: referenced inst is not defined.");
+ }
+ inst.operands.add(operand);
+ }
+ break;
+ case Slang::Misc::TokenType::Identifier:
+ {
+ auto identifier = tokenReader.ReadToken().Content;
+ if (identifier.startsWith("_"))
+ {
+ operand.type = SpvSnippet::ASMOperandType::ObjectReference;
+ operand.content =
+ StringToInt(identifier.subString(1, identifier.getLength() - 1));
+ inst.operands.add(operand);
+ }
+ else if (identifier == "resultType")
+ {
+ operand.type = SpvSnippet::ASMOperandType::ResultTypeId;
+ operand.content = -1;
+ if (tokenReader.AdvanceIf("*"))
+ {
+ // A "*" at operand qualifies the use of `resultType` with
+ // a storage class, but does not modify `resultType` itself.
+ auto storageClass = tokenReader.ReadWord();
+ auto spvStorageClass = translateStorageClass(storageClass);
+ operand.content = spvStorageClass;
+ snippet->usedResultTypeStorageClasses.add(spvStorageClass);
+ }
+ inst.operands.add(operand);
+ }
+ else if (identifier == "resultId")
+ {
+ operand.type = SpvSnippet::ASMOperandType::ResultId;
+ inst.operands.add(operand);
+ }
+ else
+ {
+ SLANG_ASSERT(!"Invalid SPV ASM operand.");
+ }
+ }
+ break;
+ default:
+ insideOperandList = false;
+ break;
+ }
+ }
+ snippet->instructions.add(inst);
+ }
+ }
+ catch (const Slang::Misc::TextFormatException&)
+ {
+ SLANG_ASSERT(!"Invalid ASM format.");
+ }
+ return snippet;
+}
+
+
+}
diff --git a/source/slang/slang-ir-spirv-snippet.h b/source/slang/slang-ir-spirv-snippet.h
new file mode 100644
index 000000000..74a9b8cd7
--- /dev/null
+++ b/source/slang/slang-ir-spirv-snippet.h
@@ -0,0 +1,61 @@
+// slang-ir-spirv-legalize.h
+#pragma once
+#include "../core/slang-basic.h"
+#include "spirv/unified1/spirv.h"
+
+namespace Slang
+{
+//
+// [2.2: Terms]
+//
+// > Word: 32 bits.
+//
+// Despite the importance to SPIR-V, the `spirv.h` header doesn't
+// define a type for words, so we'll do it here.
+
+/// A SPIR-V word.
+typedef uint32_t SpvWord;
+
+/// Represents a parsed Spv ASM from intrinsic definition.
+struct SpvSnippet : public RefObject
+{
+ enum class ASMOperandType
+ {
+ // Plain SpvWord to inline without modifications.
+ SpvWord,
+ // Represents the result type of the intrinsic.
+ ResultTypeId,
+ // Represents the result Id of the ASM inst.
+ ResultId,
+ // Represents a reference to an intrinsic argument (e.g. `_1`).
+ ObjectReference,
+ // Represents a reference to an ASM inst (e.g. `%t`).
+ InstReference,
+ };
+
+ struct ASMOperand
+ {
+ ASMOperandType type;
+
+ // The value of the spv word when type is `SpvWord`, or
+ // the reference name when type is `ObjectReference`
+ // (e.g. an argument reference (_1) has `content` == 1).
+ int content;
+ };
+
+ struct ASMInst
+ {
+ SpvWord opCode;
+ List<ASMOperand> operands;
+ };
+
+ List<ASMInst> instructions;
+ List<SpvStorageClass> usedResultTypeStorageClasses;
+
+ SpvStorageClass resultStorageClass = SpvStorageClassMax;
+
+ static RefPtr<SpvSnippet> parse(UnownedStringSlice definition);
+};
+
+
+}
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index fe60fb480..60aaafa83 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -2604,6 +2604,12 @@ namespace Slang
operands);
}
+ IRPtrType* IRBuilder::getPtrType(IROp op, IRType* valueType, IRIntegerValue addressSpace)
+ {
+ IRInst* operands[] = {valueType, getIntValue(getIntType(), addressSpace)};
+ return (IRPtrType*)getType(op, 2, operands);
+ }
+
IRArrayTypeBase* IRBuilder::getArrayTypeBase(
IROp op,
IRType* elementType,
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 9eb03c269..7542a883a 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1116,6 +1116,13 @@ struct IRPtrTypeBase : IRType
{
IRType* getValueType() { return (IRType*)getOperand(0); }
+ bool hasAddressSpace() { return getOperandCount() > 1; }
+
+ IRIntegerValue getAddressSpace()
+ {
+ return getOperandCount() > 1 ? static_cast<IRIntLit*>(getOperand(1))->getValue() : -1;
+ }
+
IR_PARENT_ISA(PtrTypeBase)
};
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index ef92558cc..913be346e 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -1100,6 +1100,11 @@ void TargetRequest::addCapability(CapabilityAtom capability)
cookedCapabilities = CapabilitySet::makeEmpty();
}
+void TargetRequest::setDirectSPIRVEmitMode()
+{
+ m_emitSPIRVDirectly = true;
+ cookedCapabilities.makeEmpty();
+}
CapabilitySet TargetRequest::getTargetCaps()
{
@@ -1131,9 +1136,18 @@ CapabilitySet TargetRequest::getTargetCaps()
case CodeGenTarget::GLSL:
case CodeGenTarget::GLSL_Vulkan:
case CodeGenTarget::GLSL_Vulkan_OneDesc:
+ atoms.add(CapabilityAtom::GLSL);
+ break;
case CodeGenTarget::SPIRV:
case CodeGenTarget::SPIRVAssembly:
- atoms.add(CapabilityAtom::GLSL);
+ if (m_emitSPIRVDirectly)
+ {
+ atoms.add(CapabilityAtom::SPIRV_DIRECT);
+ }
+ else
+ {
+ atoms.add(CapabilityAtom::GLSL);
+ }
break;
case CodeGenTarget::HLSL:
diff --git a/tests/spirv/direct-spirv-compute-simple.slang b/tests/spirv/direct-spirv-compute-simple.slang
new file mode 100644
index 000000000..39b9074ed
--- /dev/null
+++ b/tests/spirv/direct-spirv-compute-simple.slang
@@ -0,0 +1,23 @@
+// direct-spirv-compute-simple.slang
+
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -xslang -emit-spirv-directly
+
+// Test runinng a shader generated from direct SPIR-V emit.
+
+//TEST_INPUT:set resultBuffer = out ubuffer(data=[0 0 0 0], stride=4)
+RWStructuredBuffer<uint> resultBuffer;
+
+[numthreads(4,1,1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint threadId = dispatchThreadID.x;
+ uint result = threadId + 1;
+ result = result - 1;
+ result = result * 2;
+ result = result / 2;
+ result = result % 3;
+ result = (result ^ 7);
+ result = (result & 7);
+ result = (result | 8);
+ resultBuffer[threadId] = result;
+}
diff --git a/tests/spirv/direct-spirv-compute-simple.slang.expected.txt b/tests/spirv/direct-spirv-compute-simple.slang.expected.txt
new file mode 100644
index 000000000..4fc6bca7a
--- /dev/null
+++ b/tests/spirv/direct-spirv-compute-simple.slang.expected.txt
@@ -0,0 +1,4 @@
+F
+E
+D
+F \ No newline at end of file
diff --git a/tests/spirv/direct-spirv-control-flow-2.slang b/tests/spirv/direct-spirv-control-flow-2.slang
new file mode 100644
index 000000000..cc908100e
--- /dev/null
+++ b/tests/spirv/direct-spirv-control-flow-2.slang
@@ -0,0 +1,47 @@
+// direct-spirv-control-flow-2.slang
+
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -output-using-type -xslang -emit-spirv-directly
+
+// Test direct SPIR-V emit on control flows.
+
+//TEST_INPUT:set resultBuffer = out ubuffer(data=[0 0 0 0], stride=4)
+RWStructuredBuffer<uint> resultBuffer;
+
+uint test(uint p)
+{
+ int result = 0;
+ for (int i = 0; i < 5; i++)
+ {
+ result += i*2;
+ }
+ switch (p)
+ {
+ case 0:
+ result = result - 1;
+ break;
+ case 1:
+ result = result + 1;
+ break;
+ default:
+ result = result * 2;
+ break;
+ }
+ if (p > 2)
+ {
+ switch (p)
+ {
+ case 3:
+ result++;
+ break;
+ }
+ }
+ return result;
+}
+
+[numthreads(4,1,1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint threadId = dispatchThreadID.x;
+ uint result = test(threadId);
+ resultBuffer[threadId] = result;
+}
diff --git a/tests/spirv/direct-spirv-control-flow-2.slang.expected.txt b/tests/spirv/direct-spirv-control-flow-2.slang.expected.txt
new file mode 100644
index 000000000..36929d66f
--- /dev/null
+++ b/tests/spirv/direct-spirv-control-flow-2.slang.expected.txt
@@ -0,0 +1,5 @@
+type: uint32_t
+19
+21
+40
+41
diff --git a/tests/spirv/direct-spirv-control-flow.slang b/tests/spirv/direct-spirv-control-flow.slang
new file mode 100644
index 000000000..9efddeb12
--- /dev/null
+++ b/tests/spirv/direct-spirv-control-flow.slang
@@ -0,0 +1,30 @@
+// direct-spirv-control-flow.slang
+
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -xslang -emit-spirv-directly
+
+// Test direct SPIRV emit on control fl.
+
+//TEST_INPUT:set resultBuffer = out ubuffer(data=[0 0 0 0], stride=4)
+RWStructuredBuffer<uint> resultBuffer;
+
+uint test(uint p)
+{
+ int result = 0;
+ if (p == 0)
+ {
+ result = 5;
+ }
+ else
+ {
+ result = 6;
+ }
+ return result;
+}
+
+[numthreads(4,1,1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint threadId = dispatchThreadID.x;
+ uint result = test(threadId);
+ resultBuffer[threadId] = result;
+}
diff --git a/tests/spirv/direct-spirv-control-flow.slang.expected.txt b/tests/spirv/direct-spirv-control-flow.slang.expected.txt
new file mode 100644
index 000000000..c0bcc1c4a
--- /dev/null
+++ b/tests/spirv/direct-spirv-control-flow.slang.expected.txt
@@ -0,0 +1,4 @@
+5
+6
+6
+6
diff --git a/tools/gfx/vulkan/render-vk.cpp b/tools/gfx/vulkan/render-vk.cpp
index 592cbaac1..88770dbb9 100644
--- a/tools/gfx/vulkan/render-vk.cpp
+++ b/tools/gfx/vulkan/render-vk.cpp
@@ -5573,6 +5573,7 @@ Result VKDevice::initVulkanInstanceAndDevice(bool useValidationLayer)
extendedFeatures.bufferDeviceAddressFeatures.pNext = (void*)deviceCreateInfo.pNext;
deviceCreateInfo.pNext = &extendedFeatures.bufferDeviceAddressFeatures;
deviceExtensions.add(VK_KHR_BUFFER_DEVICE_ADDRESS_EXTENSION_NAME);
+
m_features.add("buffer-device-address");
}
@@ -5606,6 +5607,7 @@ Result VKDevice::initVulkanInstanceAndDevice(bool useValidationLayer)
deviceCreateInfo.enabledExtensionCount = uint32_t(deviceExtensions.getCount());
deviceCreateInfo.ppEnabledExtensionNames = deviceExtensions.getBuffer();
+
if (m_api.vkCreateDevice(m_api.m_physicalDevice, &deviceCreateInfo, nullptr, &m_device) != VK_SUCCESS)
return SLANG_FAIL;
SLANG_RETURN_ON_FAIL(m_api.initDeviceProcs(m_device));
diff --git a/tools/gfx/vulkan/vk-api.h b/tools/gfx/vulkan/vk-api.h
index 746648470..0bb3339fb 100644
--- a/tools/gfx/vulkan/vk-api.h
+++ b/tools/gfx/vulkan/vk-api.h
@@ -226,7 +226,7 @@ struct VulkanExtendedFeatureProperties
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_BUFFER_DEVICE_ADDRESS_FEATURES};
// Inline uniform block features
VkPhysicalDeviceInlineUniformBlockFeaturesEXT inlineUniformBlockFeatures = {
- VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_BUFFER_DEVICE_ADDRESS_FEATURES};
+ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_INLINE_UNIFORM_BLOCK_FEATURES_EXT};
};
struct VulkanApi