diff options
| author | Anders Leino <aleino@nvidia.com> | 2024-09-09 20:08:29 +0300 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-09-09 10:08:29 -0700 |
| commit | 170558c9618252933286955c6d010c8e3735652a (patch) | |
| tree | 1dff2c0bb18a7244114a8a37950f0fbbd8dfa594 | |
| parent | 110d82fb75f19ac83e3a297e4783304481f66ce7 (diff) | |
Initial WGSL support (#5006)
* Add WGSL as a target
This is required for #4807.
* C-like emitter: Allow the function header emission to be overloaded
WGSL-style function headers are pretty different from normal C-style headers:
Normal C-style headers:
ReturnType Func(...)
void VoidFunc(...)
WGSL-style headers:
fn Func(...) -> ReturnType
fn VoidFunc(...)
This change allows the header style to be overloaded, in order to accomodate WGSL-style
headers as required to resolve issue #4807, but retains normal C-style headers as the
default implementation.
[1] https://www.w3.org/TR/WGSL/#function-declaration-sec
* C-like emitter: Allow emission of switch case selectors to be overloaded
The C-like emitter will emit code like this:
switch(a.x)
{
case 0:
case 1:
{
...
} break;
...
}
This is not allowed in WGSL. Instead, selectors for cases that share a body must [1] be
separated by commas, like this:
switch(a.x)
{
case 0, 1:
{
...
} break;
...
}
To prepare for addressing issue #4807, this patch makes the emission of switch case
selectors overloadable.
[1] https://www.w3.org/TR/WGSL/#syntax-case_selectors
* C-like emitter: Support WGSL-style declarations
This patch helps to address issue 4807.
C-like languages declare variables like this:
i32 a;
WGSL declares variables like this:
var a : i32
The patch introduces overloads so that the forthcoming WGSL emitter can output WGSL-style
declarations, which helps to resolve #4807.
* C-like emitter: Support overloading of declarators
Unlike C-like languages, WGSL does not support the following types at the syntax level,
via declarators:
- arrays
- pointers
- references
For this reason, this patch introduces support for overloading the declarator emitter,
in order to help address issue #4807.
C-like languages:
int a[3]; // Array-ness of type is mixed into the "declarator"
WGSL:
var a : array<int, 3>; // Array-ness of type is part of the... type_specifier!
* C-like emitter: Allow struct declaration separator to be overridden
C-like languages use ';' as a separator, and languages like e.g. WGSL use ','.
This change prepares for addressing issue #4807.
* C-like emitter: Allow overriding of whether pointer-like syntax is necessary
Things like e.g. structured buffers map to "ptr-to-array" in WGSL, but ptr-typed
expressions don't always need C-style pointer-like syntax.
Therefore, make it overrideable whether or not such syntax is emitted in various cases in
order to address #4807.
* C-like emitter: Emit parenthesis to avoid warning about & and + precedence
This helps with #4807 because WGSL compilers (e.g. Tint) treat absence of parenthesis as
an error.
* C-like emitter: Add hook for emitting struct field attributes
WGSL requires @align attributes to specify explicit field alignment in certain cases.
Thus, this patch prepares for addressing #4807.
* C-like emitter: Add hook for emitting global param types
Declarations of structured buffers map to global array declarations in WGSL.
However, in all other cases such as when structured buffers are used in operands, their
types map to *ptr*-to-array.
This patch makes it possible for the WGSL back-end to say that structured buffers
generally map to "ptr-to-array" types, but still have a special case of just "array" when
declaring the global shader parameter.
Thus, this patch helps with addressing #4807.
* IR lowering: Use std140 for WGSL uniform buffers
This patch just cuts out some logic that prevented std140 to be chosen for WGSL uniform
buffers.
Note that WGSL buffers in the uniform address space is not quite std140, but for now it's
close enough to avoid compile issues.
Later on, a custom layout should be created for WGSL uniform buffers.
When that's done, this change will be revisited, but for now it helps to resolve #4807.
* Don't emit line directives in WGSL by default
WGSL does not support line directives [1].
The plan currently seems to be to instead support source-map [2].
This is part of addressing issue #4807.
[1] https://github.com/gpuweb/gpuweb/issues/606
[2] https://github.com/mozilla/source-map
* WGSL IR legalization: Map SV's
The implementation closely follows the cooresponding one for Metal.
Supported:
- DispatchThreadID
- GroupID
- GroupThreadID
- GroupThreadID
Unsupported:
- GSInstanceID
This is not complete, but it helps to address #4807.
* WGSL emitter: Add support for basic language constructs
A lot of the basics are added in order to generate correct WGSL code for basic Slang language constructs.
This addresses issue #4807.
This adds support for at least the following:
- statments
- if statements
- ternary operator
- while statement
- for statements
- variable declarations
- switch statements
- Note: Slang may emit non-constant case expressions, see issue 4834
- literals
- integer literals
- u?int[16|32|64]_t
- float and half literals
- bool literals
- vector literals and splatting (e.g 1.xxx)
- function definitions
- assignments
- +=, *=, /=
- array assignments
- vector assignments/updates
- swizzles of other vectors
- from matrix rows ('m[i]' notation)
- from matrix cols (using swizzle notation, e.g 'm._11_12_13')
- matrix assignments/updates
- to rows ('m[i]' notation)
- to cols (using swizzle notation, e.g 'm._11_12_13')
- declarations
- arrays
[1] https://www.w3.org/TR/WGSL/#syntax-switch_body
* Add some WGSL capabilities
This patch registers some WGSL capabilities required to pass many of the initial compute
shader compile tests.
Many capabilities still remain to be added -- this is just an initial set to help resolve
issue #4807.
- asint
- min and max
- cos and sin
- all and any
* WGSL and C-like emitters: Add hack to bitcast case expression
In WGSL, the switch condition and case types must match.
https://www.w3.org/TR/WGSL/#switch-statement
Slang currently allows these types to mismatch, as pointed out in #4921.
Issue #4921 should eventually be addressed in the front-end by a patch like [1].
However, at the moment that would break Falcor tests.
Thus, this patch temporarily works around the issue in the WGSL emitter only in order to
help resolve #4807.
In the future, the Falcor tests should be fixed, this patch should be dropped and [1]
should be merged instead.
[1] a32156ef52f43b8503b2c77f2f1d51220ab9bdea
22 files changed, 1630 insertions, 56 deletions
diff --git a/include/slang.h b/include/slang.h index 9755415b3..777cd406b 100644 --- a/include/slang.h +++ b/include/slang.h @@ -603,6 +603,7 @@ extern "C" SLANG_METAL_LIB, ///< Metal library SLANG_METAL_LIB_ASM, ///< Metal library assembly SLANG_HOST_SHARED_LIBRARY, ///< A shared library/Dll for host code (for hosting CPU/OS) + SLANG_WGSL, ///< WebGPU shading language SLANG_TARGET_COUNT_OF, }; @@ -735,6 +736,7 @@ extern "C" SLANG_SOURCE_LANGUAGE_CUDA, SLANG_SOURCE_LANGUAGE_SPIRV, SLANG_SOURCE_LANGUAGE_METAL, + SLANG_SOURCE_LANGUAGE_WGSL, SLANG_SOURCE_LANGUAGE_COUNT_OF, }; diff --git a/source/compiler-core/slang-artifact-desc-util.cpp b/source/compiler-core/slang-artifact-desc-util.cpp index a4190992c..9794cc90e 100644 --- a/source/compiler-core/slang-artifact-desc-util.cpp +++ b/source/compiler-core/slang-artifact-desc-util.cpp @@ -197,6 +197,7 @@ SLANG_HIERARCHICAL_ENUM(ArtifactKind, SLANG_ARTIFACT_KIND, SLANG_ARTIFACT_KIND_E x(CUDA, Source) \ x(Metal, Source) \ x(Slang, Source) \ + x(WGSL, Source) \ x(KernelLike, Base) \ x(DXIL, KernelLike) \ x(DXBC, KernelLike) \ @@ -288,6 +289,7 @@ SLANG_HIERARCHICAL_ENUM(ArtifactStyle, SLANG_ARTIFACT_STYLE, SLANG_ARTIFACT_STYL case SLANG_METAL: return Desc::make(Kind::Source, Payload::Metal, Style::Kernel, 0); case SLANG_METAL_LIB: return Desc::make(Kind::Executable, Payload::MetalAIR, Style::Kernel, 0); case SLANG_METAL_LIB_ASM: return Desc::make(Kind::Assembly, Payload::MetalAIR, Style::Kernel, 0); + case SLANG_WGSL: return Desc::make(Kind::Source, Payload::WGSL, Style::Kernel, 0); default: break; } @@ -330,6 +332,7 @@ SLANG_HIERARCHICAL_ENUM(ArtifactStyle, SLANG_ARTIFACT_STYLE, SLANG_ARTIFACT_STYL case Payload::Cpp: return (desc.style == Style::Host) ? SLANG_HOST_CPP_SOURCE : SLANG_CPP_SOURCE; case Payload::CUDA: return SLANG_CUDA_SOURCE; case Payload::Metal: return SLANG_METAL; + case Payload::WGSL: return SLANG_WGSL; default: break; } break; diff --git a/source/compiler-core/slang-artifact.h b/source/compiler-core/slang-artifact.h index 400c85b2e..6d65aafba 100644 --- a/source/compiler-core/slang-artifact.h +++ b/source/compiler-core/slang-artifact.h @@ -143,6 +143,7 @@ enum class ArtifactPayload : uint8_t CUDA, ///< CUDA source Metal, ///< Metal source Slang, ///< Slang source + WGSL, ///< WGSL source KernelLike, ///< GPU Kernel like diff --git a/source/core/slang-type-text-util.cpp b/source/core/slang-type-text-util.cpp index 9fa91abf6..9f9deb92c 100644 --- a/source/core/slang-type-text-util.cpp +++ b/source/core/slang-type-text-util.cpp @@ -63,6 +63,7 @@ static const TypeTextUtil::CompileTargetInfo s_compileTargetInfos[] = { SLANG_METAL, "metal", "metal", "Metal shader source" }, { SLANG_METAL_LIB, "metallib", "metallib", "Metal Library Bytecode" }, { SLANG_METAL_LIB_ASM, "metallib-asm" "metallib-asm", "Metal Library Bytecode assembly" }, + { SLANG_WGSL, "wgsl", "wgsl", "WebGPU shading language source" }, }; static const NamesDescriptionValue s_languageInfos[] = diff --git a/source/slang-record-replay/util/emum-to-string.h b/source/slang-record-replay/util/emum-to-string.h index 7a7952555..7226edc04 100644 --- a/source/slang-record-replay/util/emum-to-string.h +++ b/source/slang-record-replay/util/emum-to-string.h @@ -34,6 +34,7 @@ namespace SlangRecord CASE(SLANG_METAL_LIB); CASE(SLANG_METAL_LIB_ASM); CASE(SLANG_HOST_SHARED_LIBRARY); + CASE(SLANG_WGSL); CASE(SLANG_TARGET_COUNT_OF); default: Slang::StringBuilder str; diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 10a6254c1..a8241bf73 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -5668,7 +5668,7 @@ vector<T,N> acosh(vector<T,N> x) // Test if all components are non-zero (HLSL SM 1.0) __generic<T : __BuiltinType> [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] bool all(T x) { __target_switch @@ -5679,6 +5679,8 @@ bool all(T x) __intrinsic_asm "all"; case metal: __intrinsic_asm "all"; + case wgsl: + __intrinsic_asm "all"; case spirv: let zero = __default<T>(); if (__isInt<T>()) @@ -5806,7 +5808,7 @@ int3 WorkgroupSize(); __generic<T : __BuiltinType> [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] bool any(T x) { __target_switch @@ -5817,6 +5819,8 @@ bool any(T x) __intrinsic_asm "any"; case metal: __intrinsic_asm "any"; + case wgsl: + __intrinsic_asm "any"; case spirv: let zero = __default<T>(); if (__isInt<T>()) @@ -6142,7 +6146,7 @@ vector<T,N> asinh(vector<T,N> x) // Reinterpret bits as an int (HLSL SM 4.0) [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_4_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_4_0)] int asint(float x) { __target_switch @@ -6152,6 +6156,7 @@ int asint(float x) case glsl: __intrinsic_asm "floatBitsToInt"; case hlsl: __intrinsic_asm "asint"; case metal: __intrinsic_asm "as_type<$TR>($0)"; + case wgsl: __intrinsic_asm "bitcast<$TR>($0)"; case spirv: return spirv_asm { OpBitcast $$int result $x }; @@ -6285,7 +6290,7 @@ void asuint(double value, out uint lowbits, out uint highbits) // Reinterpret bits as a uint (HLSL SM 4.0) [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_4_0)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_4_0)] uint asuint(float x) { __target_switch @@ -6295,6 +6300,7 @@ uint asuint(float x) case glsl: __intrinsic_asm "floatBitsToUint"; case hlsl: __intrinsic_asm "asuint"; case metal: __intrinsic_asm "as_type<$TR>($0)"; + case wgsl: __intrinsic_asm "bitcast<$TR>($0)"; case spirv: return spirv_asm { OpBitcast $$uint result $x }; @@ -7025,7 +7031,7 @@ void clip(matrix<T,N,M> x) // Cosine __generic<T : __BuiltinFloatingPointType> [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T cos(T x) { __target_switch @@ -7035,6 +7041,7 @@ T cos(T x) case glsl: __intrinsic_asm "cos"; case hlsl: __intrinsic_asm "cos"; case metal: __intrinsic_asm "cos"; + case wgsl: __intrinsic_asm "cos"; case spirv: return spirv_asm { OpExtInst $$T result glsl450 Cos $x }; @@ -10427,7 +10434,7 @@ matrix<T, N, M> mad(matrix<T, N, M> mvalue, matrix<T, N, M> avalue, matrix<T, N, // maximum __generic<T : __BuiltinIntegerType> [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T max(T x, T y) { // Note: a stdlib implementation of `max` (or `min`) will require splitting @@ -10440,6 +10447,7 @@ T max(T x, T y) case hlsl: __intrinsic_asm "max"; case glsl: __intrinsic_asm "max"; case metal: __intrinsic_asm "max"; + case wgsl: __intrinsic_asm "max"; case cuda: __intrinsic_asm "$P_max($0, $1)"; case cpp: __intrinsic_asm "$P_max($0, $1)"; case spirv: @@ -10656,7 +10664,7 @@ vector<T,N> fmax3(vector<T,N> x, vector<T,N> y, vector<T,N> z) // minimum __generic<T : __BuiltinIntegerType> [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T min(T x, T y) { __target_switch @@ -10664,6 +10672,7 @@ T min(T x, T y) case hlsl: case glsl: case metal: + case wgsl: __intrinsic_asm "min"; case cuda: case cpp: @@ -11103,13 +11112,14 @@ T mul(vector<T, N> x, vector<T, N> y) // vector-matrix __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] vector<T, M> mul(vector<T, N> left, matrix<T, N, M> right) { __target_switch { case glsl: __intrinsic_asm "($1 * $0)"; case metal: __intrinsic_asm "($1 * $0)"; + case wgsl: __intrinsic_asm "($1 * $0)"; case hlsl: __intrinsic_asm "mul"; case spirv: return spirv_asm { OpMatrixTimesVector $$vector<T, M> result $right $left @@ -12166,7 +12176,7 @@ matrix<int, N, M> sign(matrix<T, N, M> x) __generic<T : __BuiltinFloatingPointType> [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] T sin(T x) { __target_switch @@ -12176,6 +12186,7 @@ T sin(T x) case glsl: __intrinsic_asm "sin"; case hlsl: __intrinsic_asm "sin"; case metal: __intrinsic_asm "sin"; + case wgsl: __intrinsic_asm "sin"; case spirv: return spirv_asm { OpExtInst $$T result glsl450 Sin $x }; diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef index a173a332f..9e9b94151 100644 --- a/source/slang/slang-capabilities.capdef +++ b/source/slang/slang-capabilities.capdef @@ -111,6 +111,10 @@ def metal : target + textualTarget; /// [Target] def spirv : target; +/// Represents the WebGPU shading language code generation target. +/// [Target] +def wgsl : target + textualTarget; + // Capabilities that stand for target SPIR-V versions for the GLSL backend. // These are not compilation targets. We will convert `_spirv_*` to `glsl_spirv_*` during compilation. @@ -228,15 +232,15 @@ def _cuda_sm_9_0 : _cuda_sm_8_0; /// All code-gen targets /// [Compound] -alias any_target = hlsl | metal | glsl | c | cpp | cuda | spirv; +alias any_target = hlsl | metal | glsl | c | cpp | cuda | spirv | wgsl; /// All non-asm code-gen targets /// [Compound] -alias any_textual_target = hlsl | metal | glsl | c | cpp | cuda; +alias any_textual_target = hlsl | metal | glsl | c | cpp | cuda | wgsl; /// All slang-gfx compatible code-gen targets /// [Compound] -alias any_gfx_target = hlsl | metal | glsl | spirv; +alias any_gfx_target = hlsl | metal | glsl | spirv | wgsl; /// All "cpp syntax" code-gen targets /// [Compound] @@ -266,6 +270,10 @@ alias cpp_cuda_glsl_hlsl_spirv = cpp | cuda | glsl | hlsl | spirv; /// [Compound] alias cpp_cuda_glsl_hlsl_metal_spirv = cpp | cuda | glsl | hlsl | metal | spirv; +/// CPP, CUDA, GLSL, HLSL, Metal, SPIRV and WGSL code-gen targets +/// [Compound] +alias cpp_cuda_glsl_hlsl_metal_spirv_wgsl = cpp | cuda | glsl | hlsl | metal | spirv | wgsl; + /// CPP, CUDA, and HLSL code-gen targets /// [Compound] alias cpp_cuda_hlsl = cpp | cuda | hlsl; @@ -1178,6 +1186,7 @@ alias sm_4_0_version = _sm_4_0 | spirv_1_0 | _cuda_sm_2_0 | metal + | wgsl | cpp ; diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 4bb420fa7..541085b4e 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -1715,6 +1715,7 @@ namespace Slang case CodeGenTarget::PyTorchCppBinding: case CodeGenTarget::CSource: case CodeGenTarget::Metal: + case CodeGenTarget::WGSL: { RefPtr<ExtensionTracker> extensionTracker = _newExtensionTracker(target); diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index b8ee4dc9c..62e4c5f4a 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -94,6 +94,7 @@ namespace Slang Metal = SLANG_METAL, MetalLib = SLANG_METAL_LIB, MetalLibAssembly = SLANG_METAL_LIB_ASM, + WGSL = SLANG_WGSL, CountOf = SLANG_TARGET_COUNT_OF, }; diff --git a/source/slang/slang-doc-markdown-writer.cpp b/source/slang/slang-doc-markdown-writer.cpp index c9dd6d9c8..e32a738c7 100644 --- a/source/slang/slang-doc-markdown-writer.cpp +++ b/source/slang/slang-doc-markdown-writer.cpp @@ -483,6 +483,10 @@ static DocMarkdownWriter::Requirement _getRequirementFromTargetToken(const Token { return Requirement{ CodeGenTarget::Metal, targetName }; } + else if (isCapabilityDerivedFrom(targetCap, CapabilityAtom::wgsl)) + { + return Requirement{ CodeGenTarget::WGSL, targetName }; + } return Requirement{ CodeGenTarget::Unknown, String() }; } diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 1893929f8..caf3613a7 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -95,6 +95,10 @@ struct CLikeSourceEmitter::ComputeEmitActionsContext { return SourceLanguage::Metal; } + case CodeGenTarget::WGSL: + { + return SourceLanguage::WGSL; + } } } @@ -151,7 +155,7 @@ void CLikeSourceEmitter::ensureTypePrelude(IRType* type) } } -void CLikeSourceEmitter::emitDeclarator(DeclaratorInfo* declarator) +void CLikeSourceEmitter::emitDeclaratorImpl(DeclaratorInfo* declarator) { if (!declarator) return; @@ -341,13 +345,18 @@ void CLikeSourceEmitter::_emitPostfixTypeAttr(IRAttr* attr) // we may need to handle it here. } +void CLikeSourceEmitter::emitSimpleTypeAndDeclaratorImpl(IRType* type, DeclaratorInfo* declarator) +{ + emitSimpleType(type); + emitDeclarator(declarator); +} + void CLikeSourceEmitter::_emitType(IRType* type, DeclaratorInfo* declarator) { switch (type->getOp()) { default: - emitSimpleType(type); - emitDeclarator(declarator); + emitSimpleTypeAndDeclarator(type, declarator); break; case kIROp_RateQualifiedType: @@ -648,7 +657,7 @@ bool CLikeSourceEmitter::maybeEmitParens(EmitOpInfo& outerPrec, const EmitOpInfo bool needParens = (prec.leftPrecedence <= outerPrec.leftPrecedence) || (prec.rightPrecedence <= outerPrec.rightPrecedence); - // While Slang correctly removes some of parentheses, DXC prints warnings + // While Slang correctly removes some of parentheses, many compilers print warnings // for common mistakes when parentheses are not used with certain combinations // of the operations. We emit parentheses to avoid the warnings. // @@ -676,6 +685,12 @@ bool CLikeSourceEmitter::maybeEmitParens(EmitOpInfo& outerPrec, const EmitOpInfo { needParens = true; } + // a + b & c => (a + b) & c + else if (prec.rightPrecedence == EPrecedence::kEPrecedence_Additive_Right + && outerPrec.rightPrecedence == EPrecedence::kEPrecedence_BitAnd_Left) + { + needParens = true; + } if (needParens) { @@ -1657,11 +1672,16 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst) return true; } +bool CLikeSourceEmitter::isPointerSyntaxRequiredImpl(IRInst* /* inst */) +{ + return doesTargetSupportPtrTypes(); +} + void CLikeSourceEmitter::emitDereferenceOperand(IRInst* inst, EmitOpInfo const& outerPrec) { EmitOpInfo newOuterPrec = outerPrec; - if (doesTargetSupportPtrTypes()) + if (isPointerSyntaxRequiredImpl(inst)) { switch (inst->getOp()) { @@ -1760,7 +1780,7 @@ void CLikeSourceEmitter::emitDereferenceOperand(IRInst* inst, EmitOpInfo const& void CLikeSourceEmitter::emitVarExpr(IRInst* inst, EmitOpInfo const& outerPrec) { - if (doesTargetSupportPtrTypes()) + if (isPointerSyntaxRequiredImpl(inst)) { auto prec = getInfo(EmitOp::Prefix); auto newOuterPrec = outerPrec; @@ -1842,7 +1862,8 @@ void CLikeSourceEmitter::emitInstResultDecl(IRInst* inst) emitRateQualifiers(inst); - if(as<IRModuleInst>(inst->getParent())) + bool isConstant(as<IRModuleInst>(inst->getParent())); + if(isConstant) { // "Ordinary" instructions at module scope are constants @@ -1857,6 +1878,9 @@ void CLikeSourceEmitter::emitInstResultDecl(IRInst* inst) case SourceLanguage::Metal: m_writer->emit("constant "); break; + case SourceLanguage::WGSL: + // This is handled by emitVarKeyword, below + break; default: m_writer->emit("const "); break; @@ -1864,6 +1888,8 @@ void CLikeSourceEmitter::emitInstResultDecl(IRInst* inst) } + emitVarKeyword(type, isConstant); + emitType(type, getName(inst)); m_writer->emit(" = "); } @@ -2297,7 +2323,7 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO IRFieldAddress* ii = (IRFieldAddress*) inst; - if (doesTargetSupportPtrTypes()) + if (isPointerSyntaxRequiredImpl(inst)) { auto prec = getInfo(EmitOp::Prefix); needClose = maybeEmitParens(outerPrec, prec); @@ -3117,6 +3143,8 @@ void CLikeSourceEmitter::_emitStoreImpl(IRStore* store) void CLikeSourceEmitter::_emitInstAsDefaultInitializedVar(IRInst* inst, IRType* type) { + emitVarKeyword(type, /* isConstant */ false); + emitType(type, getName(inst)); // On targets that support empty initializers, we will emit it. @@ -3178,6 +3206,20 @@ void CLikeSourceEmitter::emitLayoutSemantics(IRInst* inst, char const* uniformSe emitLayoutSemanticsImpl(inst, uniformSemanticSpelling, EmitLayoutSemanticOption::kPostType); } +void CLikeSourceEmitter::emitSwitchCaseSelectorsImpl(IRBasicType *const /* switchCondition */, const SwitchRegion::Case *const currentCase, const bool isDefault) +{ + for(auto caseVal : currentCase->values) + { + m_writer->emit("case "); + emitOperand(caseVal, getInfo(EmitOp::General)); + m_writer->emit(":\n"); + } + if(isDefault) + { + m_writer->emit("default:\n"); + } +} + void CLikeSourceEmitter::emitRegion(Region* inRegion) { // We will use a loop so that we can process sequential (simple) @@ -3333,17 +3375,9 @@ void CLikeSourceEmitter::emitRegion(Region* inRegion) auto defaultCase = switchRegion->defaultCase; for(auto currentCase : switchRegion->cases) { - for(auto caseVal : currentCase->values) - { - m_writer->emit("case "); - emitOperand(caseVal, getInfo(EmitOp::General)); - m_writer->emit(":\n"); - } - if(currentCase.Ptr() == defaultCase) - { - m_writer->emit("default:\n"); - } - + const bool isDefault {currentCase.Ptr() == defaultCase}; + IRBasicType *const switchConditionType {as<IRBasicType>(switchRegion->getCondition()->getDataType())}; + emitSwitchCaseSelectors(switchConditionType, currentCase.Ptr(), isDefault); m_writer->indent(); m_writer->emit("{\n"); m_writer->indent(); @@ -3449,9 +3483,16 @@ void CLikeSourceEmitter::emitSimpleFuncParamsImpl(IRFunc* func) m_writer->emit(")"); } -void CLikeSourceEmitter::emitSimpleFuncImpl(IRFunc* func) +void CLikeSourceEmitter::emitFuncHeaderImpl(IRFunc* func) { auto resultType = func->getResultType(); + auto name = getName(func); + emitType(resultType, name); + emitSimpleFuncParamsImpl(func); +} + +void CLikeSourceEmitter::emitSimpleFuncImpl(IRFunc* func) +{ // Deal with decorations that need // to be emitted as attributes @@ -3467,12 +3508,8 @@ void CLikeSourceEmitter::emitSimpleFuncImpl(IRFunc* func) emitFunctionPreambleImpl(func); - auto name = getName(func); - emitFuncDecorations(func); - - emitType(resultType, name); - emitSimpleFuncParamsImpl(func); + emitFuncHeader(func); emitSemantics(func); // TODO: encode declaration vs. definition @@ -3688,6 +3725,11 @@ void CLikeSourceEmitter::emitStruct(IRStructType* structType) m_writer->emit(";\n\n"); } +void CLikeSourceEmitter::emitStructDeclarationSeparatorImpl() +{ + m_writer->emit(";"); +} + void CLikeSourceEmitter::emitStructDeclarationsBlock(IRStructType* structType, bool allowOffsetLayout) { m_writer->emit("\n{\n"); @@ -3716,11 +3758,13 @@ void CLikeSourceEmitter::emitStructDeclarationsBlock(IRStructType* structType, b emitPackOffsetModifier(fieldKey, fieldType, packOffsetDecoration); } } + emitStructFieldAttributes(structType, ff); emitMemoryQualifiers(fieldKey); emitType(fieldType, getName(fieldKey)); emitSemantics(fieldKey, allowOffsetLayout); emitPostDeclarationAttributesForType(fieldType); - m_writer->emit(";\n"); + emitStructDeclarationSeparator(); + m_writer->emit("\n"); } m_writer->dedent(); @@ -3931,6 +3975,8 @@ void CLikeSourceEmitter::emitParameterGroup(IRGlobalParam* varDecl, IRUniformPar emitParameterGroupImpl(varDecl, type); } +void CLikeSourceEmitter::emitVarKeywordImpl(IRType * /* type */, bool /* isConstant */) {} + void CLikeSourceEmitter::emitVar(IRVar* varDecl) { auto allocatedType = varDecl->getDataType(); @@ -3969,6 +4015,8 @@ void CLikeSourceEmitter::emitVar(IRVar* varDecl) #endif emitRateQualifiersAndAddressSpace(varDecl); + emitVarKeyword(varType, /* isConstant */ false); + emitType(varType, getName(varDecl)); emitSemantics(varDecl); @@ -4099,6 +4147,7 @@ void CLikeSourceEmitter::emitGlobalVar(IRGlobalVar* varDecl) emitVarModifiers(layout, varDecl, varType); emitRateQualifiersAndAddressSpace(varDecl); + emitVarKeyword(varType, /* isConstant */ true); emitType(varType, getName(varDecl)); // TODO: These shouldn't be needed for ordinary @@ -4172,7 +4221,8 @@ void CLikeSourceEmitter::emitGlobalParam(IRGlobalParam* varDecl) emitDecorationLayoutSemantics(varDecl, "register"); emitRateQualifiersAndAddressSpace(varDecl); - emitType(varType, getName(varDecl)); + emitVarKeyword(varType, /* isConstant */ false); + emitGlobalParamType(varType, getName(varDecl)); emitSemantics(varDecl); diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index 00ad156d1..be769f31f 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -248,7 +248,8 @@ public: // void ensureTypePrelude(IRType* type); - void emitDeclarator(DeclaratorInfo* declarator); + void emitDeclarator(DeclaratorInfo* declarator) {emitDeclaratorImpl(declarator);} + virtual void emitDeclaratorImpl(DeclaratorInfo* declarator); void emitType(IRType* type, const StringSliceLoc* nameLoc) { emitTypeImpl(type, nameLoc); } void emitType(IRType* type, Name* name); @@ -256,6 +257,7 @@ public: void emitType(IRType* type); void emitType(IRType* type, Name* name, SourceLoc const& nameLoc); void emitType(IRType* type, NameLoc const& nameAndLoc); + virtual void emitGlobalParamType(IRType* type, String const& name) {emitType(type, name);} bool hasExplicitConstantBufferOffset(IRInst* cbufferType); bool isSingleElementConstantBuffer(IRInst* cbufferType); bool shouldForceUnpackConstantBufferElements(IRInst* cbufferType); @@ -368,8 +370,11 @@ public: /// Emit high-level statements for the body of a function. void emitFunctionBody(IRGlobalValueWithCode* code); + void emitFuncHeader(IRFunc* func) { emitFuncHeaderImpl(func); } void emitSimpleFunc(IRFunc* func) { emitSimpleFuncImpl(func); } + void emitSwitchCaseSelectors(IRBasicType *const switchConditionType, const SwitchRegion::Case *const currentCase, const bool isDefault) {emitSwitchCaseSelectorsImpl(switchConditionType, currentCase, isDefault);} + void emitParamType(IRType* type, String const& name) { emitParamTypeImpl(type, name); } void emitFuncDecl(IRFunc* func); @@ -394,10 +399,14 @@ public: void emitStructDeclarationsBlock(IRStructType* structType, bool allowOffsetLayout); void emitClass(IRClassType* structType); + void emitStructDeclarationSeparator() {emitStructDeclarationSeparatorImpl();} + virtual void emitStructDeclarationSeparatorImpl(); + /// Emit type attributes that should appear after, e.g., a `struct` keyword void emitPostKeywordTypeAttributes(IRInst* inst) { emitPostKeywordTypeAttributesImpl(inst); } virtual void emitMemoryQualifiers(IRInst* /*varInst*/) {}; + virtual void emitStructFieldAttributes(IRStructType * /* structType */, IRStructField * /* field */) {}; void emitInterpolationModifiers(IRInst* varInst, IRType* valueType, IRVarLayout* layout); void emitMeshShaderModifiers(IRInst* varInst); virtual void emitPackOffsetModifier(IRInst* /*varInst*/, IRType* /*valueType*/, IRPackOffsetDecoration* /*decoration*/) {}; @@ -421,6 +430,7 @@ public: void emitGlobalInst(IRInst* inst); virtual void emitGlobalInstImpl(IRInst* inst); + virtual bool isPointerSyntaxRequiredImpl(IRInst* inst); void ensureInstOperand(ComputeEmitActionsContext* ctx, IRInst* inst, EmitAction::Level requiredLevel = EmitAction::Level::Definition); @@ -486,6 +496,11 @@ public: virtual void emitPreModuleImpl(); virtual void emitPostModuleImpl(); + virtual void emitSimpleTypeAndDeclaratorImpl(IRType* type, DeclaratorInfo* declarator); + void emitSimpleTypeAndDeclarator(IRType* type, DeclaratorInfo* declarator) {emitSimpleTypeAndDeclaratorImpl(type, declarator);}; + virtual void emitVarKeywordImpl(IRType * type, bool isConstant); + void emitVarKeyword(IRType * type, bool isConstant) {emitVarKeywordImpl(type, isConstant);} + virtual void beforeComputeEmitActions(IRModule* module) { SLANG_UNUSED(module); }; virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, AddressSpace addressSpace) { SLANG_UNUSED(rate); SLANG_UNUSED(addressSpace); } @@ -501,6 +516,7 @@ public: virtual void emitTypeImpl(IRType* type, const StringSliceLoc* nameLoc); virtual void emitSimpleValueImpl(IRInst* inst); virtual void emitModuleImpl(IRModule* module, DiagnosticSink* sink); + virtual void emitFuncHeaderImpl(IRFunc* func); virtual void emitSimpleFuncImpl(IRFunc* func); virtual void emitVarExpr(IRInst* inst, EmitOpInfo const& outerPrec); virtual void emitOperandImpl(IRInst* inst, EmitOpInfo const& outerPrec); @@ -511,6 +527,7 @@ public: virtual void emitLoopControlDecorationImpl(IRLoopControlDecoration* decl) { SLANG_UNUSED(decl); } virtual void emitIfDecorationsImpl(IRIfElse* ifInst) { SLANG_UNUSED(ifInst); } virtual void emitSwitchDecorationsImpl(IRSwitch* switchInst) { SLANG_UNUSED(switchInst); } + virtual void emitSwitchCaseSelectorsImpl(IRBasicType *const switchConditionType, const SwitchRegion::Case *const currentCase, const bool isDefault); virtual void emitFuncDecorationImpl(IRDecoration* decoration) { SLANG_UNUSED(decoration); } virtual void emitLivenessImpl(IRInst* inst); diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp new file mode 100644 index 000000000..0a4cca407 --- /dev/null +++ b/source/slang/slang-emit-wgsl.cpp @@ -0,0 +1,1005 @@ +#include "slang-emit-wgsl.h" + +// A note on row/column "terminology reversal". +// +// This is an "terminology reversing" implementation in the sense that +// * "column" in Slang code maps to "row" in the generated WGSL code, and +// * "row" in Slang code maps to "column" in the generated WGSL code. +// +// This means that matrices in Slang code end up getting translated to +// matrices that actually represent the transpose of what the Slang matrix +// represented. +// Both API's adopt the standard matrix multiplication convention whereby the +// column count of the matrix on the left hand side needs to match row count of +// the matrix on the right hand side. +// For these reasons, and due to the fact that (M_1 ... M_n)^T = M_n^T ... M_1^T, +// the order of matrix (and vector-matrix products) products must also reversed +// in the WGSL code. +// +// This may lead to confusion (which is why this note is referenced in several +// places), but the benefit of doing this is that the generated WGSL code is +// simpler to generate and should be faster to compile. +// A "terminology preserving" implementation would have to generate lots of +// 'transpose' calls, or else perform more complicated transformations that +// end up duplicating expressions many times. + +namespace Slang { + +void WGSLSourceEmitter::emitSwitchCaseSelectorsImpl( + IRBasicType *const switchConditionType, + const SwitchRegion::Case *const currentCase, const bool isDefault + ) +{ + // WGSL has special syntax for blocks sharing case labels: + // "case 2, 3, 4: ...;" instead of the C-like syntax + // "case 2: case 3: case 4: ...;". + + m_writer->emit("case "); + for (auto caseVal : currentCase->values) + { + // TODO: Fix this in the front-end [1], remove the if-path and just do the else-path. + // We can't do that at the moment because it would break Falcor [2]. + // [1] https://github.com/shader-slang/slang/pull/5025/commits/a32156ef52f43b8503b2c77f2f1d51220ab9bdea + // [2] https://github.com/shader-slang/slang/pull/5025#issuecomment-2334495120 + if (caseVal->getOp() == kIROp_IntLit) + { + auto caseLitInst = static_cast<IRConstant*>(caseVal); + IRBasicType *const caseInstType = as<IRBasicType>(caseLitInst->getDataType()); + // WGSL doesn't allow switch condition and case type mismatches, see [1]. + // Thus we need to insert explicit conversions. + // Doing a wrapping cast will match Slang's de facto semantics, according to + // [2]. + // (This is just a bitcast, assuming a two's complement representation.) + // [1] https://www.w3.org/TR/WGSL/#switch-statement + // [2] https://github.com/shader-slang/slang/issues/4921 + const bool needBitcast = + caseInstType->getBaseType() != switchConditionType->getBaseType(); + if (needBitcast) + { + m_writer->emit("bitcast<"); + emitType(switchConditionType); + m_writer->emit(">("); + } + emitOperand(caseVal, getInfo(EmitOp::General)); + if (needBitcast) + { + m_writer->emit(")"); + } + } + else + { + emitOperand(caseVal, getInfo(EmitOp::General)); + } + m_writer->emit(", "); + } + if (isDefault) + { + m_writer->emit("default, "); + } + m_writer->emit(":\n"); +} + +void WGSLSourceEmitter::emitParameterGroupImpl( + IRGlobalParam* varDecl, IRUniformParameterGroupType* type +) +{ + auto varLayout = getVarLayout(varDecl); + SLANG_RELEASE_ASSERT(varLayout); + + for (auto attr : varLayout->getOffsetAttrs()) + { + + const LayoutResourceKind kind = attr->getResourceKind(); + switch (kind) + { + case LayoutResourceKind::VaryingInput: + case LayoutResourceKind::VaryingOutput: + m_writer->emit("@location("); + m_writer->emit(attr->getOffset()); + m_writer->emit(")"); + if (attr->getSpace()) + { + // TODO: Not sure what 'space' should map to in WGSL + SLANG_ASSERT(false); + } + break; + + case LayoutResourceKind::SpecializationConstant: + // TODO: + // Consider moving to a differently named function. + // This is not technically an attribute, but a declaration. + // + // https://www.w3.org/TR/WGSL/#override-decls + m_writer->emit("override"); + break; + + case LayoutResourceKind::Uniform: + case LayoutResourceKind::ConstantBuffer: + case LayoutResourceKind::ShaderResource: + case LayoutResourceKind::UnorderedAccess: + case LayoutResourceKind::SamplerState: + case LayoutResourceKind::DescriptorTableSlot: + m_writer->emit("@binding("); + m_writer->emit(attr->getOffset()); + m_writer->emit(") "); + m_writer->emit("@group("); + m_writer->emit(attr->getSpace()); + m_writer->emit(") "); + break; + + } + + } + + auto elementType = type->getElementType(); + m_writer->emit("var<uniform> "); + m_writer->emit(getName(varDecl)); + m_writer->emit(" : "); + emitType(elementType); + m_writer->emit(";\n"); +} + +void WGSLSourceEmitter::emitEntryPointAttributesImpl( + IRFunc* irFunc, IREntryPointDecoration* entryPointDecor + ) +{ + auto stage = entryPointDecor->getProfile().getStage(); + + switch (stage) + { + + case Stage::Fragment: + m_writer->emit("@fragment\n"); + break; + case Stage::Vertex: + m_writer->emit("@vertex\n"); + break; + + case Stage::Compute: + { + m_writer->emit("@compute\n"); + + { + Int sizeAlongAxis[kThreadGroupAxisCount]; + getComputeThreadGroupSize(irFunc, sizeAlongAxis); + + m_writer->emit("@workgroup_size("); + for (int ii = 0; ii < kThreadGroupAxisCount; ++ii) + { + if (ii != 0) + m_writer->emit(", "); + m_writer->emit(sizeAlongAxis[ii]); + } + m_writer->emit(")\n"); + } + } + break; + + default: + SLANG_ABORT_COMPILATION("unsupported stage."); + } + +} + +// This is 'function_header' from the WGSL specification +void WGSLSourceEmitter::emitFuncHeaderImpl(IRFunc* func) +{ + Slang::IRType * resultType = func->getResultType(); + auto name = getName(func); + + m_writer->emit("fn "); + m_writer->emit(name); + + emitSimpleFuncParamsImpl(func); + + // An absence of return type is expressed by skipping the optional '->' part of the + // header. + if (resultType->getOp() != kIROp_VoidType) + { + m_writer->emit(" -> "); + emitType(resultType); + } +} + +void WGSLSourceEmitter::emitSimpleFuncParamImpl(IRParam* param) +{ + if (auto sysSemanticDecor = param->findDecoration<IRTargetSystemValueDecoration>()) + { + m_writer->emit("@builtin("); + m_writer->emit(sysSemanticDecor->getSemantic()); + m_writer->emit(")"); + } + + CLikeSourceEmitter::emitSimpleFuncParamImpl(param); +} + +void WGSLSourceEmitter::emitMatrixType( + IRType *const elementType, const IRIntegerValue& rowCountWGSL, + const IRIntegerValue& colCountWGSL + ) +{ + // WGSL uses CxR convention + m_writer->emit("mat"); + m_writer->emit(colCountWGSL); + m_writer->emit("x"); + m_writer->emit(rowCountWGSL); + m_writer->emit("<"); + emitType(elementType); + m_writer->emit(">"); +} + +void WGSLSourceEmitter::emitStructDeclarationSeparatorImpl() +{ + m_writer->emit(","); +} + +static bool isPowerOf2(const uint32_t n) +{ + return (n != 0U) && ((n - 1U) & n) == 0U; +} + +void WGSLSourceEmitter::emitStructFieldAttributes( + IRStructType * structType, IRStructField * field + ) +{ + // Tint emits errors unless we explicitly spell out the layout in some cases, so emit + // offset and align attribtues for all fields. + IRSizeAndAlignmentDecoration *const sizeAndAlignmentDecoration = + structType->findDecoration<IRSizeAndAlignmentDecoration>(); + // NullDifferential struct doesn't have size and alignment decoration + if (sizeAndAlignmentDecoration == nullptr) + return; + SLANG_ASSERT(sizeAndAlignmentDecoration->getAlignment() > IRIntegerValue{0}); + SLANG_ASSERT( + sizeAndAlignmentDecoration->getAlignment() <= IRIntegerValue{UINT32_MAX} + ); + const uint32_t structAlignment = + static_cast<uint32_t>(sizeAndAlignmentDecoration->getAlignment()); + IROffsetDecoration *const fieldOffsetDecoration = + field->findDecoration<IROffsetDecoration>(); + SLANG_ASSERT(fieldOffsetDecoration->getOffset() >= IRIntegerValue{0}); + SLANG_ASSERT(fieldOffsetDecoration->getOffset() <= IRIntegerValue{UINT32_MAX}); + SLANG_ASSERT(isPowerOf2(structAlignment)); + const uint32_t fieldOffset = + static_cast<uint32_t>(fieldOffsetDecoration->getOffset()); + // Alignment is GCD(fieldOffset, structAlignment) + // TODO: Use builtin/intrinsic (e.g. __builtin_ffs) + uint32_t fieldAlignment = 1U; + while (((fieldAlignment & (structAlignment | fieldOffset)) == 0U)) + fieldAlignment = fieldAlignment << 1U; + + m_writer->emit("@align("); + m_writer->emit(fieldAlignment); + m_writer->emit(")"); +} + +bool WGSLSourceEmitter::isPointerSyntaxRequiredImpl(IRInst* inst) +{ + // Structured buffers are mapped to 'array' types, which don't need dereferencing + if (inst->getOp() == kIROp_RWStructuredBufferGetElementPtr) + return false; + + // Don't emit "->" to access fields in resource structs + if (inst->getOp() == kIROp_FieldAddress) + return false; + + // Don't emit "*" to access fields in resource structs + if (inst->getOp() == kIROp_GlobalParam) + return false; + + // Emit 'globalVar' instead of "*&globalVar" + if (inst->getOp() == kIROp_GlobalVar) + return false; + + return true; +} + +void WGSLSourceEmitter::emit(const AddressSpace addressSpace) +{ + switch (addressSpace) + { + case AddressSpace::Uniform: + m_writer->emit("uniform"); + break; + + case AddressSpace::StorageBuffer: + m_writer->emit("storage"); + break; + + case AddressSpace::Generic: + m_writer->emit("function"); + break; + + case AddressSpace::ThreadLocal: + m_writer->emit("private"); + break; + + case AddressSpace::GroupShared: + m_writer->emit("workgroup"); + break; + } +} + +void WGSLSourceEmitter::emitSimpleTypeImpl(IRType* type) +{ + switch (type->getOp()) + { + + case kIROp_HLSLRWStructuredBufferType: + { + auto structuredBufferType = as<IRHLSLStructuredBufferTypeBase>(type); + m_writer->emit("ptr<"); + emit(AddressSpace::StorageBuffer); + m_writer->emit(", "); + m_writer->emit("array"); + m_writer->emit("<"); + emitType(structuredBufferType->getElementType()); + m_writer->emit(">"); + m_writer->emit(", read_write"); + m_writer->emit(">"); + } + break; + + case kIROp_HLSLStructuredBufferType: + { + auto structuredBufferType = as<IRHLSLStructuredBufferTypeBase>(type); + m_writer->emit("ptr<"); + emit(AddressSpace::StorageBuffer); + m_writer->emit(", "); + m_writer->emit("array"); + m_writer->emit("<"); + emitType(structuredBufferType->getElementType()); + m_writer->emit(">"); + m_writer->emit(", read"); + m_writer->emit(">"); + } + break; + + case kIROp_VoidType: + { + // There is no void type in WGSL. + // A return type of "void" is expressed by skipping the end part of the + // 'function_header' term: + // " + // function_header : + // 'fn' ident '(' param_list ? ')' + // ( '->' attribute * template_elaborated_ident ) ? + // " + // In other words, in WGSL we should never even get to the point where we're + // asking to emit 'void'. + SLANG_UNEXPECTED("'void' type emitted"); + return; + } + + case kIROp_FloatType: + m_writer->emit("f32"); + break; + case kIROp_DoubleType: + // There is no "f64" type in WGSL + SLANG_UNEXPECTED("'double' type emitted"); + break; + case kIROp_Int8Type: + case kIROp_UInt8Type: + // There is no "[i|u]8" type in WGSL + SLANG_UNEXPECTED("8 bit integer type emitted"); + break; + case kIROp_HalfType: + m_f16ExtensionEnabled = true; + m_writer->emit("f16"); + break; + case kIROp_BoolType: + m_writer->emit("bool"); + break; + case kIROp_IntType: + m_writer->emit("i32"); + break; + case kIROp_UIntType: + m_writer->emit("u32"); + break; + case kIROp_UInt64Type: + { + m_writer->emit(getDefaultBuiltinTypeName(type->getOp())); + return; + } + case kIROp_Int16Type: + case kIROp_UInt16Type: + SLANG_UNEXPECTED("16 bit integer value emitted"); + return; + case kIROp_Int64Type: + case kIROp_IntPtrType: + m_writer->emit("i64"); + return; + case kIROp_UIntPtrType: + m_writer->emit("u64"); + return; + case kIROp_StructType: + m_writer->emit(getName(type)); + return; + + case kIROp_VectorType: + { + auto vecType = (IRVectorType*)type; + emitVectorTypeNameImpl( + vecType->getElementType(), getIntVal(vecType->getElementCount()) + ); + return; + } + case kIROp_MatrixType: + { + auto matType = (IRMatrixType*)type; + // We map matrices in Slang to WGSL matrices that represent the transpose. + // (See note on "terminology reversal".) + const IRIntegerValue colCountWGSL = getIntVal(matType->getRowCount()); + const IRIntegerValue rowCountWGSL = getIntVal(matType->getColumnCount()); + emitMatrixType(matType->getElementType(), rowCountWGSL, colCountWGSL); + return; + } + case kIROp_SamplerStateType: + { + m_writer->emit("sampler"); + return; + } + + case kIROp_SamplerComparisonStateType: + { + m_writer->emit("sampler_comparison"); + return; + } + + case kIROp_PtrType: + case kIROp_InOutType: + case kIROp_OutType: + case kIROp_RefType: + case kIROp_ConstRefType: + { + auto ptrType = cast<IRPtrTypeBase>(type); + m_writer->emit("ptr<"); + emit((AddressSpace)ptrType->getAddressSpace()); + m_writer->emit(", "); + emitType((IRType*)ptrType->getValueType()); + m_writer->emit(">"); + return; + } + + case kIROp_ArrayType: + { + m_writer->emit("array<"); + emitType((IRType*)type->getOperand(0)); + m_writer->emit(", "); + emitVal(type->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(">"); + return; + } + default: + break; + + } + +} + +void WGSLSourceEmitter::emitLayoutQualifiersImpl(IRVarLayout* layout) +{ + + for (auto attr : layout->getOffsetAttrs()) + { + LayoutResourceKind kind = attr->getResourceKind(); + + // TODO: + // This is not correct. For the moment this is just here as a hack to make + // @binding and @group unique, so that we can pass WGSL compile tests. + // This will have to be revisited when we actually want to supply resources to + // shaders. + if (kind == LayoutResourceKind::DescriptorTableSlot) + { + m_writer->emit("@binding("); + m_writer->emit(attr->getOffset()); + m_writer->emit(") "); + m_writer->emit("@group("); + m_writer->emit(attr->getSpace()); + m_writer->emit(") "); + + return; + } + } + +} + +void WGSLSourceEmitter::emitVarKeywordImpl(IRType * type, const bool isConstant) +{ + if (isConstant) + m_writer->emit("const"); + else + m_writer->emit("var"); + if (type->getOp() == kIROp_HLSLRWStructuredBufferType) + { + m_writer->emit("<"); + m_writer->emit("storage, read_write"); + m_writer->emit(">"); + } + else if (type->getOp() == kIROp_HLSLStructuredBufferType) + { + m_writer->emit("<"); + m_writer->emit("storage, read"); + m_writer->emit(">"); + } +} + +void WGSLSourceEmitter::_emitType(IRType* type, DeclaratorInfo* declarator) +{ + // C-like languages bake array-ness, pointer-ness and reference-ness into the + // declarator, which happens in the default _emitType implementation. + // WGSL on the other hand, don't have special syntax -- these are just types. + switch (type->getOp()) + { + case kIROp_ArrayType: + case kIROp_AttributedType: + case kIROp_UnsizedArrayType: + emitSimpleTypeAndDeclarator(type, declarator); + break; + default: + CLikeSourceEmitter::_emitType(type, declarator); + break; + } +} + +void WGSLSourceEmitter::emitDeclaratorImpl(DeclaratorInfo* declarator) +{ + if (!declarator) return; + + m_writer->emit(" "); + + switch (declarator->flavor) + { + case DeclaratorInfo::Flavor::Name: + { + auto nameDeclarator = (NameDeclaratorInfo*)declarator; + m_writer->emitName(*nameDeclarator->nameAndLoc); + } + break; + + case DeclaratorInfo::Flavor::SizedArray: + { + // Sized arrays are just types (array<T, N>) in WGSL -- they are not + // supported at the syntax level + // https://www.w3.org/TR/WGSL/#array + SLANG_UNEXPECTED("Sized array declarator"); + } + break; + + case DeclaratorInfo::Flavor::UnsizedArray: + { + // Unsized arrays are just types (array<T>) in WGSL -- they are not + // supported at the syntax level + // https://www.w3.org/TR/WGSL/#array + SLANG_UNEXPECTED("Unsized array declarator"); + } + break; + + case DeclaratorInfo::Flavor::Ptr: + { + // Pointers (ptr<AS,T,AM>) are just types in WGSL -- they are not supported at + // the syntax level + // https://www.w3.org/TR/WGSL/#ref-ptr-types + SLANG_UNEXPECTED("Pointer declarator"); + } + break; + + case DeclaratorInfo::Flavor::Ref: + { + // References (ref<AS,T,AM>) are just types in WGSL -- they are not supported + // at the syntax level + // https://www.w3.org/TR/WGSL/#ref-ptr-types + SLANG_UNEXPECTED("Reference declarator"); + } + break; + + case DeclaratorInfo::Flavor::LiteralSizedArray: + { + // Sized arrays are just types (array<T, N>) in WGSL -- they are not supported + // at the syntax level + // https://www.w3.org/TR/WGSL/#array + SLANG_UNEXPECTED("Literal-sized array declarator"); + } + break; + + case DeclaratorInfo::Flavor::Attributed: + { + auto attributedDeclarator = (AttributedDeclaratorInfo*)declarator; + auto instWithAttributes = attributedDeclarator->instWithAttributes; + for (auto attr : instWithAttributes->getAllAttrs()) + { + _emitPostfixTypeAttr(attr); + } + emitDeclarator(attributedDeclarator->next); + } + break; + + default: + SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unknown declarator flavor"); + break; + } +} + +void WGSLSourceEmitter::emitSimpleTypeAndDeclaratorImpl( + IRType* type, DeclaratorInfo* declarator + ) +{ + if (declarator) + { + emitDeclarator(declarator); + m_writer->emit(" : "); + } + emitSimpleType(type); +} + +void WGSLSourceEmitter::emitSimpleValueImpl(IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_IntLit: + { + auto litInst = static_cast<IRConstant*>(inst); + + IRBasicType* type = as<IRBasicType>(inst->getDataType()); + if (type) + { + switch (type->getBaseType()) + { + default: + + case BaseType::Int8: + case BaseType::UInt8: + { + SLANG_UNEXPECTED("8 bit integer value emitted"); + break; + } + case BaseType::Int16: + case BaseType::UInt16: + { + SLANG_UNEXPECTED("16 bit integer value emitted"); + break; + } + case BaseType::Int: + { + m_writer->emit("i32("); + m_writer->emit(int32_t(litInst->value.intVal)); + m_writer->emit(")"); + return; + } + case BaseType::UInt: + { + m_writer->emit("u32("); + m_writer->emit(UInt(uint32_t(litInst->value.intVal))); + m_writer->emit(")"); + break; + } + case BaseType::Int64: + { + m_writer->emit("i64("); + m_writer->emitInt64(int64_t(litInst->value.intVal)); + m_writer->emit(")"); + break; + } + case BaseType::UInt64: + { + m_writer->emit("u64("); + SLANG_COMPILE_TIME_ASSERT( + sizeof(litInst->value.intVal) >= sizeof(uint64_t) + ); + m_writer->emitUInt64(uint64_t(litInst->value.intVal)); + m_writer->emit(")"); + break; + } + case BaseType::IntPtr: + { +#if SLANG_PTR_IS_64 + m_writer->emit("i64("); + m_writer->emitInt64(int64_t(litInst->value.intVal)); + m_writer->emit(")"); +#else + m_writer->emit("i32("); + m_writer->emit(int(litInst->value.intVal)); + m_writer->emit(")"); +#endif + break; + } + case BaseType::UIntPtr: + { +#if SLANG_PTR_IS_64 + m_writer->emit("u64("); + m_writer->emitUInt64(uint64_t(litInst->value.intVal)); + m_writer->emit(")"); +#else + m_writer->emit("u32("); + m_writer->emit(UInt(uint32_t(litInst->value.intVal))); + m_writer->emit(")"); +#endif + break; + } + + } + } + else + { + // If no type... just output what we have + m_writer->emit(litInst->value.intVal); + } + break; + } + + case kIROp_FloatLit: + { + auto litInst = static_cast<IRConstant*>(inst); + + IRBasicType* type = as<IRBasicType>(inst->getDataType()); + if (type) + { + switch (type->getBaseType()) + { + default: + + case BaseType::Half: + { + m_writer->emit(litInst->value.floatVal); + m_writer->emit("h"); + m_f16ExtensionEnabled = true; + } + break; + + case BaseType::Float: + { + m_writer->emit(litInst->value.floatVal); + m_writer->emit("f"); + } + break; + + case BaseType::Double: + { + // There is not "f64" in WGSL + SLANG_UNEXPECTED("'double' type emitted"); + } + break; + } + } + else + { + // If no type... just output what we have + m_writer->emit(litInst->value.floatVal); + } + } + break; + + case kIROp_BoolLit: + { + bool val = ((IRConstant*)inst)->value.intVal != 0; + m_writer->emit(val ? "true" : "false"); + } + break; + + default: + SLANG_UNIMPLEMENTED_X("val case for emit"); + break; + } + + +} + +void WGSLSourceEmitter::emitParamTypeImpl(IRType* type, const String& name) +{ + emitType(type, name); +} + +bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) +{ + EmitOpInfo outerPrec = inOuterPrec; + + switch (inst->getOp()) + { + + case kIROp_MakeVectorFromScalar: + { + // In WGSL this is done by calling the vec* overloads listed in [1] + // [1] https://www.w3.org/TR/WGSL/#value-constructor-builtin-function + emitType(inst->getDataType()); + m_writer->emit("("); + auto prec = getInfo(EmitOp::Prefix); + emitOperand(inst->getOperand(0), rightSide(outerPrec, prec)); + m_writer->emit(")"); + return true; + } + break; + + case kIROp_BitCast: + { + // In WGSL there is a built-in bitcast function! + // https://www.w3.org/TR/WGSL/#bitcast-builtin + m_writer->emit("bitcast"); + m_writer->emit("<"); + emitType(inst->getDataType()); + m_writer->emit(">"); + m_writer->emit("("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } + break; + + case kIROp_MakeArray: + case kIROp_MakeStruct: + { + // It seems there are currently no designated initializers in WGSL. + // Similarly for array initializers. + // https://github.com/gpuweb/gpuweb/issues/4210 + + // There is a constructor named like the struct/array type itself + auto type = inst->getDataType(); + emitType(type); + m_writer->emit("( "); + UInt argCount = inst->getOperandCount(); + for (UInt aa = 0; aa < argCount; ++aa) + { + if (aa != 0) m_writer->emit(", "); + emitOperand(inst->getOperand(aa), getInfo(EmitOp::General)); + } + m_writer->emit(" )"); + + return true; + } + break; + + case kIROp_MakeArrayFromElement: + { + // It seems there are currently no array initializers in WGSL. + + // There is a constructor named like the array type itself + auto type = inst->getDataType(); + emitType(type); + m_writer->emit("("); + UInt argCount = + (UInt)cast<IRIntLit>( + cast<IRArrayType>(inst->getDataType())->getElementCount() + )->getValue(); + for (UInt aa = 0; aa < argCount; ++aa) + { + if (aa != 0) m_writer->emit(", "); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + } + m_writer->emit(")"); + return true; + } + break; + + case kIROp_StructuredBufferLoad: + case kIROp_RWStructuredBufferLoad: + { + // Structured buffers are just arrays in WGSL + auto base = inst->getOperand(0); + emitOperand(base, outerPrec); + m_writer->emit("["); + emitOperand(inst->getOperand(1), EmitOpInfo()); + m_writer->emit("]"); + return true; + } + break; + + case kIROp_Rsh: + case kIROp_Lsh: + { + // Shift amounts must be an unsigned type in WGSL + // https://www.w3.org/TR/WGSL/#bit-expr + IRInst *const shiftAmount = inst->getOperand(1); + IRType *const shiftAmountType = shiftAmount->getDataType(); + if (shiftAmountType->getOp() == kIROp_IntType) + { + // Dawn complains about "mixing '<<' and '|' requires parenthesis", so let's + // add parenthesis. + m_writer->emit("("); + + const auto emitOp = getEmitOpForOp(inst->getOp()); + const auto info = getInfo(emitOp); + + const bool needClose = maybeEmitParens(outerPrec, info); + emitOperand(inst->getOperand(0), leftSide(outerPrec, info)); + m_writer->emit(" "); + m_writer->emit(info.op); + m_writer->emit(" "); + m_writer->emit("bitcast<u32>("); + emitOperand(inst->getOperand(1), rightSide(outerPrec, info)); + m_writer->emit(")"); + maybeCloseParens(needClose); + + m_writer->emit(")"); + return true; + } + } + break; + + } + + return false; +} + +void WGSLSourceEmitter::emitVectorTypeNameImpl( + IRType* elementType, IRIntegerValue elementCount + ) +{ + + if (elementCount > 1) + { + m_writer->emit("vec"); + m_writer->emit(elementCount); + m_writer->emit("<"); + emitSimpleType(elementType); + m_writer->emit(">"); + } + else + { + emitSimpleType(elementType); + } +} + +void WGSLSourceEmitter::emitOperandImpl(IRInst* inst, const EmitOpInfo& outerPrec) +{ + // In WGSL, the structured buffer types are converted to ptr<AS, array<E>, AM> + // everywhere, except for the global parameter declaration. + // Thus, when these globals are used in expressions, we need an ampersand. + + if (inst->getOp() == kIROp_GlobalParam) + { + switch (inst->getDataType()->getOp()) + { + case kIROp_HLSLStructuredBufferType: + case kIROp_HLSLRWStructuredBufferType: + + m_writer->emit("(&"); + CLikeSourceEmitter::emitOperandImpl(inst, outerPrec); + m_writer->emit(")"); + return; + } + } + + CLikeSourceEmitter::emitOperandImpl(inst, outerPrec); +} + +void WGSLSourceEmitter::emitGlobalParamType(IRType* type, const String& name) +{ + // In WGSL, the structured buffer types are converted to ptr<AS, array<E>, AM> + // everywhere, except for the global parameter declaration. + + switch (type->getOp()) + { + + case kIROp_HLSLStructuredBufferType: + case kIROp_HLSLRWStructuredBufferType: + { + StringSliceLoc nameAndLoc(name.getUnownedSlice()); + NameDeclaratorInfo nameDeclarator(&nameAndLoc); + emitDeclarator(&nameDeclarator); + m_writer->emit(" : "); + auto structuredBufferType = as<IRHLSLStructuredBufferTypeBase>(type); + m_writer->emit("array"); + m_writer->emit("<"); + emitType(structuredBufferType->getElementType()); + m_writer->emit(">"); + } + break; + + default: + + emitType(type, name); + break; + + } + +} + +void WGSLSourceEmitter::emitFrontMatterImpl(TargetRequest* /* targetReq */) +{ + if (m_f16ExtensionEnabled) + { + m_writer->emit("enable f16;\n"); + m_writer->emit("\n"); + } +} + +} // namespace Slang diff --git a/source/slang/slang-emit-wgsl.h b/source/slang/slang-emit-wgsl.h new file mode 100644 index 000000000..dacd11c3d --- /dev/null +++ b/source/slang/slang-emit-wgsl.h @@ -0,0 +1,71 @@ +#pragma once + +#include "slang-emit-c-like.h" + +namespace Slang +{ + +class WGSLSourceEmitter : public CLikeSourceEmitter +{ +public: + + WGSLSourceEmitter(const Desc& desc) + : CLikeSourceEmitter(desc) + {} + + virtual void emitParameterGroupImpl( + IRGlobalParam* varDecl, IRUniformParameterGroupType* type + ) SLANG_OVERRIDE; + virtual void emitEntryPointAttributesImpl( + IRFunc* irFunc, IREntryPointDecoration* entryPointDecor + ) SLANG_OVERRIDE; + virtual void emitSimpleTypeImpl(IRType* type) SLANG_OVERRIDE; + virtual void emitVectorTypeNameImpl( + IRType* elementType, IRIntegerValue elementCount + ) SLANG_OVERRIDE; + virtual void emitFuncHeaderImpl(IRFunc* func) SLANG_OVERRIDE; + virtual void emitSimpleValueImpl(IRInst* inst) SLANG_OVERRIDE; + virtual bool tryEmitInstExprImpl( + IRInst* inst, const EmitOpInfo& inOuterPrec + ) SLANG_OVERRIDE; + virtual void emitSwitchCaseSelectorsImpl( + IRBasicType *const switchCondition, + const SwitchRegion::Case *const currentCase, + const bool isDefault + ) SLANG_OVERRIDE; + virtual void emitSimpleTypeAndDeclaratorImpl( + IRType* type, DeclaratorInfo* declarator + ) SLANG_OVERRIDE; + virtual void emitVarKeywordImpl(IRType * type, const bool isConstant) SLANG_OVERRIDE; + virtual void emitDeclaratorImpl(DeclaratorInfo* declarator) SLANG_OVERRIDE; + virtual void emitStructDeclarationSeparatorImpl() SLANG_OVERRIDE; + virtual void emitLayoutQualifiersImpl(IRVarLayout* layout) SLANG_OVERRIDE; + virtual void emitSimpleFuncParamImpl(IRParam* param) SLANG_OVERRIDE; + virtual void emitParamTypeImpl(IRType* type, const String& name) SLANG_OVERRIDE; + virtual bool isPointerSyntaxRequiredImpl(IRInst* inst) SLANG_OVERRIDE; + virtual void _emitType(IRType* type, DeclaratorInfo* declarator) SLANG_OVERRIDE; + virtual void emitFrontMatterImpl(TargetRequest* targetReq) SLANG_OVERRIDE; + virtual void emitStructFieldAttributes( + IRStructType * structType, IRStructField * field + ) SLANG_OVERRIDE; + virtual void emitGlobalParamType(IRType* type, const String& name) SLANG_OVERRIDE; + virtual void emitOperandImpl( + IRInst* inst, const EmitOpInfo& outerPrec + ) SLANG_OVERRIDE; + + void emit(const AddressSpace addressSpace); + +private: + + // Emit the matrix type with 'rowCountWGSL' WGSL-rows and 'colCountWGSL' WGSL-columns + void emitMatrixType( + IRType *const elementType, + const IRIntegerValue& rowCountWGSL, + const IRIntegerValue& colCountWGSL + ); + + bool m_f16ExtensionEnabled {false}; + +}; + +} // namespace Slang diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index ed9e90462..2ccf075f3 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -31,6 +31,7 @@ #include "slang-ir-glsl-legalize.h" #include "slang-ir-hlsl-legalize.h" #include "slang-ir-metal-legalize.h" +#include "slang-ir-wgsl-legalize.h" #include "slang-ir-insts.h" #include "slang-ir-inline.h" #include "slang-ir-legalize-array-return-type.h" @@ -101,6 +102,7 @@ #include "slang-emit-glsl.h" #include "slang-emit-hlsl.h" #include "slang-emit-metal.h" +#include "slang-emit-wgsl.h" #include "slang-emit-cpp.h" #include "slang-emit-cuda.h" #include "slang-emit-torch.h" @@ -1234,6 +1236,12 @@ Result linkAndOptimizeIR( } break; + case CodeGenTarget::WGSL: + { + legalizeIRForWGSL(irModule, sink); + } + break; + default: break; } @@ -1535,15 +1543,28 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outAr auto targetProgram = getTargetProgram(); auto lineDirectiveMode = targetProgram->getOptionSet().getEnumOption<LineDirectiveMode>(CompilerOptionName::LineDirectiveMode); - // To try to make the default behavior reasonable, we will - // always use C-style line directives (to give the user - // good source locations on error messages from downstream - // compilers) *unless* they requested raw GLSL as the - // output (in which case we want to maximize compatibility - // with downstream tools). - if (lineDirectiveMode == LineDirectiveMode::Default && targetRequest->getTarget() == CodeGenTarget::GLSL) + // We will generally use C-style line directives in order to give the user good + // source locations on error messages from downstream compilers, but there are + // a few exceptions. + if (lineDirectiveMode == LineDirectiveMode::Default) { - lineDirectiveMode = LineDirectiveMode::GLSL; + + switch(targetRequest->getTarget()) + { + + case CodeGenTarget::GLSL: + // We want to maximize compatibility with downstream tools. + lineDirectiveMode = LineDirectiveMode::GLSL; + break; + + case CodeGenTarget::WGSL: + // WGSL doesn't support line directives. + // See https://github.com/gpuweb/gpuweb/issues/606. + lineDirectiveMode = LineDirectiveMode::None; + break; + + } + } ComPtr<IBoxValue<SourceMap>> sourceMap; @@ -1610,6 +1631,11 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outAr sourceEmitter = new MetalSourceEmitter(desc); break; } + case SourceLanguage::WGSL: + { + sourceEmitter = new WGSLSourceEmitter(desc); + break; + } default: break; } break; diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 5865d5320..01b1c20de 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -1511,6 +1511,7 @@ static bool doesTargetAllowUnresolvedFuncSymbol(TargetRequest* req) case CodeGenTarget::Metal: case CodeGenTarget::MetalLib: case CodeGenTarget::MetalLibAssembly: + case CodeGenTarget::WGSL: case CodeGenTarget::DXIL: case CodeGenTarget::DXILAssembly: case CodeGenTarget::HostCPPSource: diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp index a480ae673..d0ad7483a 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.cpp +++ b/source/slang/slang-ir-lower-buffer-element-type.cpp @@ -888,16 +888,19 @@ namespace Slang IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* bufferType) { - if (!isKhronosTarget(target->getTargetReq())) - return IRTypeLayoutRules::getNatural(); + if (target->getTargetReq()->getTarget() != CodeGenTarget::WGSL) + { + if (!isKhronosTarget(target->getTargetReq())) + return IRTypeLayoutRules::getNatural(); - // If we are just emitting GLSL, we can just use the general layout rule. - if (!target->shouldEmitSPIRVDirectly()) - return IRTypeLayoutRules::getNatural(); + // If we are just emitting GLSL, we can just use the general layout rule. + if (!target->shouldEmitSPIRVDirectly()) + return IRTypeLayoutRules::getNatural(); - // If the user specified a scalar buffer layout, then just use that. - if (target->getOptionSet().shouldUseScalarLayout()) - return IRTypeLayoutRules::getNatural(); + // If the user specified a scalar buffer layout, then just use that. + if (target->getOptionSet().shouldUseScalarLayout()) + return IRTypeLayoutRules::getNatural(); + } if (target->getOptionSet().shouldUseDXLayout()) { diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp new file mode 100644 index 000000000..e05eba78c --- /dev/null +++ b/source/slang/slang-ir-wgsl-legalize.cpp @@ -0,0 +1,347 @@ +#include "slang-ir-wgsl-legalize.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-ir-util.h" +#include "slang-parameter-binding.h" +#include "slang-ir-legalize-varying-params.h" + +namespace Slang +{ + + struct EntryPointInfo + { + IRFunc* entryPointFunc; + IREntryPointDecoration* entryPointDecor; + }; + + struct SystemValLegalizationWorkItem + { + IRInst* var; + String attrName; + UInt attrIndex; + }; + + struct WGSLSystemValueInfo + { + String wgslSystemValueName; + SystemValueSemanticName wgslSystemValueNameEnum; + ShortList<IRType*> permittedTypes; + bool isUnsupported = false; + }; + + struct LegalizeWGSLEntryPointContext + { + LegalizeWGSLEntryPointContext(DiagnosticSink* sink, IRModule* module) : + m_sink(sink), m_module(module) {} + + DiagnosticSink* m_sink; + IRModule* m_module; + + std::optional<SystemValLegalizationWorkItem> makeSystemValWorkItem(IRInst* var); + void legalizeSystemValue( + EntryPointInfo entryPoint, SystemValLegalizationWorkItem& workItem + ); + List<SystemValLegalizationWorkItem> collectSystemValFromEntryPoint( + EntryPointInfo entryPoint + ); + void legalizeSystemValueParameters(EntryPointInfo entryPoint); + void legalizeEntryPointForWGSL(EntryPointInfo entryPoint); + IRInst* tryConvertValue(IRBuilder& builder, IRInst* val, IRType* toType); + WGSLSystemValueInfo getSystemValueInfo( + String inSemanticName, String* optionalSemanticIndex, IRInst* parentVar + ); + }; + + IRInst* LegalizeWGSLEntryPointContext::tryConvertValue( + IRBuilder& builder, IRInst* val, IRType* toType + ) + { + auto fromType = val->getFullType(); + if (auto fromVector = as<IRVectorType>(fromType)) + { + if (auto toVector = as<IRVectorType>(toType)) + { + if (fromVector->getElementCount() != toVector->getElementCount()) + { + fromType = + builder.getVectorType( + fromVector->getElementType(), toVector->getElementCount() + ); + val = builder.emitVectorReshape(fromType, val); + } + } + else if (as<IRBasicType>(toType)) + { + UInt index = 0; + val = builder.emitSwizzle(fromVector->getElementType(), val, 1, &index); + if (toType->getOp() == kIROp_VoidType) + return nullptr; + } + } + else if (auto fromBasicType = as<IRBasicType>(fromType)) + { + if (fromBasicType->getOp() == kIROp_VoidType) + return nullptr; + if (!as<IRBasicType>(toType)) + return nullptr; + if (toType->getOp() == kIROp_VoidType) + return nullptr; + } + else + { + return nullptr; + } + return builder.emitCast(toType, val); + } + + + WGSLSystemValueInfo LegalizeWGSLEntryPointContext::getSystemValueInfo( + String inSemanticName, String* optionalSemanticIndex, IRInst* parentVar + ) + { + IRBuilder builder(m_module); + WGSLSystemValueInfo result = {}; + UnownedStringSlice semanticName; + UnownedStringSlice semanticIndex; + + auto hasExplicitIndex = + splitNameAndIndex( + inSemanticName.getUnownedSlice(), semanticName, semanticIndex + ); + if (!hasExplicitIndex && optionalSemanticIndex) + semanticIndex = optionalSemanticIndex->getUnownedSlice(); + + result.wgslSystemValueNameEnum = + convertSystemValueSemanticNameToEnum(semanticName); + + switch (result.wgslSystemValueNameEnum) + { + + case SystemValueSemanticName::DispatchThreadID: + { + result.wgslSystemValueName = toSlice("global_invocation_id"); + IRType *const vec3uType { + builder.getVectorType( + builder.getBasicType(BaseType::UInt), + builder.getIntValue(builder.getIntType(), 3) + ) + }; + result.permittedTypes.add(vec3uType); + } + break; + + case SystemValueSemanticName::GroupID: + { + result.wgslSystemValueName = toSlice("workgroup_id"); + result.permittedTypes.add( + builder.getVectorType( + builder.getBasicType(BaseType::UInt), + builder.getIntValue(builder.getIntType(), 3) + ) + ); + } + break; + + case SystemValueSemanticName::GroupThreadID: + { + result.wgslSystemValueName = toSlice("local_invocation_id"); + result.permittedTypes.add( + builder.getVectorType( + builder.getBasicType(BaseType::UInt), + builder.getIntValue(builder.getIntType(), 3) + ) + ); + } + break; + + case SystemValueSemanticName::GSInstanceID: + { + // No Geometry shaders in WGSL + result.isUnsupported = true; + } + break; + + default: + { + m_sink->diagnose( + parentVar, + Diagnostics::unimplementedSystemValueSemantic, semanticName + ); + return result; + } + + } + + return result; + } + + std::optional<SystemValLegalizationWorkItem> + LegalizeWGSLEntryPointContext::makeSystemValWorkItem(IRInst* var) + { + if (auto semanticDecoration = var->findDecoration<IRSemanticDecoration>()) + { + bool svPrefix = + semanticDecoration->getSemanticName().startsWithCaseInsensitive( + toSlice("sv_") + ); + if (svPrefix) + { + return + { + { + var, + String(semanticDecoration->getSemanticName()).toLower(), + (UInt)semanticDecoration->getSemanticIndex() + } + }; + } + } + + auto layoutDecor = var->findDecoration<IRLayoutDecoration>(); + if (!layoutDecor) + return {}; + auto sysValAttr = layoutDecor->findAttr<IRSystemValueSemanticAttr>(); + if (!sysValAttr) + return {}; + auto semanticName = String(sysValAttr->getName()); + auto sysAttrIndex = sysValAttr->getIndex(); + + return { { var, semanticName, sysAttrIndex } }; + } + + List<SystemValLegalizationWorkItem> + LegalizeWGSLEntryPointContext::collectSystemValFromEntryPoint( + EntryPointInfo entryPoint + ) + { + List<SystemValLegalizationWorkItem> systemValWorkItems; + for (auto param : entryPoint.entryPointFunc->getParams()) + { + auto maybeWorkItem = makeSystemValWorkItem(param); + if (maybeWorkItem.has_value()) + systemValWorkItems.add(std::move(maybeWorkItem.value())); + } + return systemValWorkItems; + } + + void + LegalizeWGSLEntryPointContext::legalizeSystemValue( + EntryPointInfo entryPoint, SystemValLegalizationWorkItem& workItem + ) + { + IRBuilder builder(entryPoint.entryPointFunc); + + auto var = workItem.var; + auto semanticName = workItem.attrName; + + auto indexAsString = String(workItem.attrIndex); + auto info = getSystemValueInfo(semanticName, &indexAsString, var); + + if (!info.permittedTypes.getCount()) + return; + + builder.addTargetSystemValueDecoration( + var, info.wgslSystemValueName.getUnownedSlice() + ); + + bool varTypeIsPermitted = false; + auto varType = var->getFullType(); + for (auto& permittedType : info.permittedTypes) + { + varTypeIsPermitted = varTypeIsPermitted || permittedType == varType; + } + + if (!varTypeIsPermitted) + { + // Note: we do not currently prefer any conversion + // example: + // * allowed types for semantic: `float4`, `uint4`, `int4` + // * user used, `float2` + // * Slang will equally prefer `float4` to `uint4` to `int4`. + // This means the type may lose data if slang selects `uint4` or `int4`. + bool foundAConversion = false; + for (auto permittedType : info.permittedTypes) + { + var->setFullType(permittedType); + builder.setInsertBefore( + entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst() + ); + + // get uses before we `tryConvertValue` since this creates a new use + List<IRUse*> uses; + for (auto use = var->firstUse; use; use = use->nextUse) + uses.add(use); + + auto convertedValue = tryConvertValue(builder, var, varType); + if (convertedValue == nullptr) + continue; + + foundAConversion = true; + copyNameHintAndDebugDecorations(convertedValue, var); + + for (auto use : uses) + builder.replaceOperand(use, convertedValue); + } + if (!foundAConversion) + { + // If we can't convert the value, report an error. + for (auto permittedType : info.permittedTypes) + { + StringBuilder typeNameSB; + getTypeNameHint(typeNameSB, permittedType); + m_sink->diagnose( + var->sourceLoc, + Diagnostics::systemValueTypeIncompatible, + semanticName, + typeNameSB.produceString() + ); + } + } + } + } + + void LegalizeWGSLEntryPointContext::legalizeSystemValueParameters( + EntryPointInfo entryPoint + ) + { + List<SystemValLegalizationWorkItem> systemValWorkItems = + collectSystemValFromEntryPoint(entryPoint); + + for (auto index = 0; index < systemValWorkItems.getCount(); index++) + { + legalizeSystemValue(entryPoint, systemValWorkItems[index]); + } + } + + void LegalizeWGSLEntryPointContext::legalizeEntryPointForWGSL( + EntryPointInfo entryPoint + ) + { + legalizeSystemValueParameters(entryPoint); + } + + void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink) + { + List<EntryPointInfo> entryPoints; + for (auto inst : module->getGlobalInsts()) + { + IRFunc *const func {as<IRFunc>(inst)}; + if (!func) + continue; + IREntryPointDecoration *const entryPointDecor = + func->findDecoration<IREntryPointDecoration>(); + if (!entryPointDecor) + continue; + EntryPointInfo info; + info.entryPointDecor = entryPointDecor; + info.entryPointFunc = func; + entryPoints.add(info); + } + + LegalizeWGSLEntryPointContext context(sink, module); + for (auto entryPoint : entryPoints) + context.legalizeEntryPointForWGSL(entryPoint); + } + +} diff --git a/source/slang/slang-ir-wgsl-legalize.h b/source/slang/slang-ir-wgsl-legalize.h new file mode 100644 index 000000000..462f93204 --- /dev/null +++ b/source/slang/slang-ir-wgsl-legalize.h @@ -0,0 +1,10 @@ +#pragma once + +#include "slang-ir.h" + +namespace Slang +{ + class DiagnosticSink; + + void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink); +} diff --git a/source/slang/slang-profile.h b/source/slang/slang-profile.h index 04d4f5112..178fbddd5 100644 --- a/source/slang/slang-profile.h +++ b/source/slang/slang-profile.h @@ -19,6 +19,7 @@ namespace Slang CUDA = SLANG_SOURCE_LANGUAGE_CUDA, SPIRV = SLANG_SOURCE_LANGUAGE_SPIRV, Metal = SLANG_SOURCE_LANGUAGE_METAL, + WGSL = SLANG_SOURCE_LANGUAGE_WGSL, CountOf = SLANG_SOURCE_LANGUAGE_COUNT_OF, }; diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index f654135a1..2447f5787 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -1831,6 +1831,7 @@ LayoutRulesFamilyImpl* getDefaultLayoutRulesFamilyForTarget(TargetRequest* targe case CodeGenTarget::GLSL: case CodeGenTarget::SPIRV: case CodeGenTarget::SPIRVAssembly: + case CodeGenTarget::WGSL: return &kGLSLLayoutRulesFamilyImpl; case CodeGenTarget::HostHostCallable: @@ -2141,6 +2142,10 @@ SourceLanguage getIntermediateSourceLanguageForTarget(TargetProgram* targetProgr { return SourceLanguage::Metal; } + case CodeGenTarget::WGSL: + { + return SourceLanguage::WGSL; + } case CodeGenTarget::CSource: { return SourceLanguage::C; diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 91ed3de5f..c78348a86 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -1838,6 +1838,10 @@ CapabilitySet TargetRequest::getTargetCaps() atoms.add(CapabilityName::metal); break; + case CodeGenTarget::WGSL: + atoms.add(CapabilityName::wgsl); + break; + default: break; } |
