summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/compiler-core/slang-artifact-desc-util.cpp3
-rw-r--r--source/compiler-core/slang-artifact.h1
-rw-r--r--source/core/slang-type-text-util.cpp1
-rw-r--r--source/slang-record-replay/util/emum-to-string.h1
-rw-r--r--source/slang/hlsl.meta.slang29
-rw-r--r--source/slang/slang-capabilities.capdef15
-rw-r--r--source/slang/slang-compiler.cpp1
-rwxr-xr-xsource/slang/slang-compiler.h1
-rw-r--r--source/slang/slang-doc-markdown-writer.cpp4
-rw-r--r--source/slang/slang-emit-c-like.cpp104
-rw-r--r--source/slang/slang-emit-c-like.h19
-rw-r--r--source/slang/slang-emit-wgsl.cpp1005
-rw-r--r--source/slang/slang-emit-wgsl.h71
-rw-r--r--source/slang/slang-emit.cpp42
-rw-r--r--source/slang/slang-ir-link.cpp1
-rw-r--r--source/slang/slang-ir-lower-buffer-element-type.cpp19
-rw-r--r--source/slang/slang-ir-wgsl-legalize.cpp347
-rw-r--r--source/slang/slang-ir-wgsl-legalize.h10
-rw-r--r--source/slang/slang-profile.h1
-rw-r--r--source/slang/slang-type-layout.cpp5
-rw-r--r--source/slang/slang.cpp4
21 files changed, 1628 insertions, 56 deletions
diff --git a/source/compiler-core/slang-artifact-desc-util.cpp b/source/compiler-core/slang-artifact-desc-util.cpp
index a4190992c..9794cc90e 100644
--- a/source/compiler-core/slang-artifact-desc-util.cpp
+++ b/source/compiler-core/slang-artifact-desc-util.cpp
@@ -197,6 +197,7 @@ SLANG_HIERARCHICAL_ENUM(ArtifactKind, SLANG_ARTIFACT_KIND, SLANG_ARTIFACT_KIND_E
x(CUDA, Source) \
x(Metal, Source) \
x(Slang, Source) \
+ x(WGSL, Source) \
x(KernelLike, Base) \
x(DXIL, KernelLike) \
x(DXBC, KernelLike) \
@@ -288,6 +289,7 @@ SLANG_HIERARCHICAL_ENUM(ArtifactStyle, SLANG_ARTIFACT_STYLE, SLANG_ARTIFACT_STYL
case SLANG_METAL: return Desc::make(Kind::Source, Payload::Metal, Style::Kernel, 0);
case SLANG_METAL_LIB: return Desc::make(Kind::Executable, Payload::MetalAIR, Style::Kernel, 0);
case SLANG_METAL_LIB_ASM: return Desc::make(Kind::Assembly, Payload::MetalAIR, Style::Kernel, 0);
+ case SLANG_WGSL: return Desc::make(Kind::Source, Payload::WGSL, Style::Kernel, 0);
default: break;
}
@@ -330,6 +332,7 @@ SLANG_HIERARCHICAL_ENUM(ArtifactStyle, SLANG_ARTIFACT_STYLE, SLANG_ARTIFACT_STYL
case Payload::Cpp: return (desc.style == Style::Host) ? SLANG_HOST_CPP_SOURCE : SLANG_CPP_SOURCE;
case Payload::CUDA: return SLANG_CUDA_SOURCE;
case Payload::Metal: return SLANG_METAL;
+ case Payload::WGSL: return SLANG_WGSL;
default: break;
}
break;
diff --git a/source/compiler-core/slang-artifact.h b/source/compiler-core/slang-artifact.h
index 400c85b2e..6d65aafba 100644
--- a/source/compiler-core/slang-artifact.h
+++ b/source/compiler-core/slang-artifact.h
@@ -143,6 +143,7 @@ enum class ArtifactPayload : uint8_t
CUDA, ///< CUDA source
Metal, ///< Metal source
Slang, ///< Slang source
+ WGSL, ///< WGSL source
KernelLike, ///< GPU Kernel like
diff --git a/source/core/slang-type-text-util.cpp b/source/core/slang-type-text-util.cpp
index 9fa91abf6..9f9deb92c 100644
--- a/source/core/slang-type-text-util.cpp
+++ b/source/core/slang-type-text-util.cpp
@@ -63,6 +63,7 @@ static const TypeTextUtil::CompileTargetInfo s_compileTargetInfos[] =
{ SLANG_METAL, "metal", "metal", "Metal shader source" },
{ SLANG_METAL_LIB, "metallib", "metallib", "Metal Library Bytecode" },
{ SLANG_METAL_LIB_ASM, "metallib-asm" "metallib-asm", "Metal Library Bytecode assembly" },
+ { SLANG_WGSL, "wgsl", "wgsl", "WebGPU shading language source" },
};
static const NamesDescriptionValue s_languageInfos[] =
diff --git a/source/slang-record-replay/util/emum-to-string.h b/source/slang-record-replay/util/emum-to-string.h
index 7a7952555..7226edc04 100644
--- a/source/slang-record-replay/util/emum-to-string.h
+++ b/source/slang-record-replay/util/emum-to-string.h
@@ -34,6 +34,7 @@ namespace SlangRecord
CASE(SLANG_METAL_LIB);
CASE(SLANG_METAL_LIB_ASM);
CASE(SLANG_HOST_SHARED_LIBRARY);
+ CASE(SLANG_WGSL);
CASE(SLANG_TARGET_COUNT_OF);
default:
Slang::StringBuilder str;
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 10a6254c1..a8241bf73 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -5668,7 +5668,7 @@ vector<T,N> acosh(vector<T,N> x)
// Test if all components are non-zero (HLSL SM 1.0)
__generic<T : __BuiltinType>
[__readNone]
-[require(cpp_cuda_glsl_hlsl_metal_spirv)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
bool all(T x)
{
__target_switch
@@ -5679,6 +5679,8 @@ bool all(T x)
__intrinsic_asm "all";
case metal:
__intrinsic_asm "all";
+ case wgsl:
+ __intrinsic_asm "all";
case spirv:
let zero = __default<T>();
if (__isInt<T>())
@@ -5806,7 +5808,7 @@ int3 WorkgroupSize();
__generic<T : __BuiltinType>
[__readNone]
-[require(cpp_cuda_glsl_hlsl_metal_spirv)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
bool any(T x)
{
__target_switch
@@ -5817,6 +5819,8 @@ bool any(T x)
__intrinsic_asm "any";
case metal:
__intrinsic_asm "any";
+ case wgsl:
+ __intrinsic_asm "any";
case spirv:
let zero = __default<T>();
if (__isInt<T>())
@@ -6142,7 +6146,7 @@ vector<T,N> asinh(vector<T,N> x)
// Reinterpret bits as an int (HLSL SM 4.0)
[__readNone]
-[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_4_0)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_4_0)]
int asint(float x)
{
__target_switch
@@ -6152,6 +6156,7 @@ int asint(float x)
case glsl: __intrinsic_asm "floatBitsToInt";
case hlsl: __intrinsic_asm "asint";
case metal: __intrinsic_asm "as_type<$TR>($0)";
+ case wgsl: __intrinsic_asm "bitcast<$TR>($0)";
case spirv: return spirv_asm {
OpBitcast $$int result $x
};
@@ -6285,7 +6290,7 @@ void asuint(double value, out uint lowbits, out uint highbits)
// Reinterpret bits as a uint (HLSL SM 4.0)
[__readNone]
-[require(cpp_cuda_glsl_hlsl_metal_spirv, shader5_sm_4_0)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_4_0)]
uint asuint(float x)
{
__target_switch
@@ -6295,6 +6300,7 @@ uint asuint(float x)
case glsl: __intrinsic_asm "floatBitsToUint";
case hlsl: __intrinsic_asm "asuint";
case metal: __intrinsic_asm "as_type<$TR>($0)";
+ case wgsl: __intrinsic_asm "bitcast<$TR>($0)";
case spirv: return spirv_asm {
OpBitcast $$uint result $x
};
@@ -7025,7 +7031,7 @@ void clip(matrix<T,N,M> x)
// Cosine
__generic<T : __BuiltinFloatingPointType>
[__readNone]
-[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)]
T cos(T x)
{
__target_switch
@@ -7035,6 +7041,7 @@ T cos(T x)
case glsl: __intrinsic_asm "cos";
case hlsl: __intrinsic_asm "cos";
case metal: __intrinsic_asm "cos";
+ case wgsl: __intrinsic_asm "cos";
case spirv: return spirv_asm {
OpExtInst $$T result glsl450 Cos $x
};
@@ -10427,7 +10434,7 @@ matrix<T, N, M> mad(matrix<T, N, M> mvalue, matrix<T, N, M> avalue, matrix<T, N,
// maximum
__generic<T : __BuiltinIntegerType>
[__readNone]
-[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)]
T max(T x, T y)
{
// Note: a stdlib implementation of `max` (or `min`) will require splitting
@@ -10440,6 +10447,7 @@ T max(T x, T y)
case hlsl: __intrinsic_asm "max";
case glsl: __intrinsic_asm "max";
case metal: __intrinsic_asm "max";
+ case wgsl: __intrinsic_asm "max";
case cuda: __intrinsic_asm "$P_max($0, $1)";
case cpp: __intrinsic_asm "$P_max($0, $1)";
case spirv:
@@ -10656,7 +10664,7 @@ vector<T,N> fmax3(vector<T,N> x, vector<T,N> y, vector<T,N> z)
// minimum
__generic<T : __BuiltinIntegerType>
[__readNone]
-[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)]
T min(T x, T y)
{
__target_switch
@@ -10664,6 +10672,7 @@ T min(T x, T y)
case hlsl:
case glsl:
case metal:
+ case wgsl:
__intrinsic_asm "min";
case cuda:
case cpp:
@@ -11103,13 +11112,14 @@ T mul(vector<T, N> x, vector<T, N> y)
// vector-matrix
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[__readNone]
-[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)]
vector<T, M> mul(vector<T, N> left, matrix<T, N, M> right)
{
__target_switch
{
case glsl: __intrinsic_asm "($1 * $0)";
case metal: __intrinsic_asm "($1 * $0)";
+ case wgsl: __intrinsic_asm "($1 * $0)";
case hlsl: __intrinsic_asm "mul";
case spirv: return spirv_asm {
OpMatrixTimesVector $$vector<T, M> result $right $left
@@ -12166,7 +12176,7 @@ matrix<int, N, M> sign(matrix<T, N, M> x)
__generic<T : __BuiltinFloatingPointType>
[__readNone]
-[require(cpp_cuda_glsl_hlsl_metal_spirv, sm_4_0_version)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)]
T sin(T x)
{
__target_switch
@@ -12176,6 +12186,7 @@ T sin(T x)
case glsl: __intrinsic_asm "sin";
case hlsl: __intrinsic_asm "sin";
case metal: __intrinsic_asm "sin";
+ case wgsl: __intrinsic_asm "sin";
case spirv: return spirv_asm {
OpExtInst $$T result glsl450 Sin $x
};
diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef
index a173a332f..9e9b94151 100644
--- a/source/slang/slang-capabilities.capdef
+++ b/source/slang/slang-capabilities.capdef
@@ -111,6 +111,10 @@ def metal : target + textualTarget;
/// [Target]
def spirv : target;
+/// Represents the WebGPU shading language code generation target.
+/// [Target]
+def wgsl : target + textualTarget;
+
// Capabilities that stand for target SPIR-V versions for the GLSL backend.
// These are not compilation targets. We will convert `_spirv_*` to `glsl_spirv_*` during compilation.
@@ -228,15 +232,15 @@ def _cuda_sm_9_0 : _cuda_sm_8_0;
/// All code-gen targets
/// [Compound]
-alias any_target = hlsl | metal | glsl | c | cpp | cuda | spirv;
+alias any_target = hlsl | metal | glsl | c | cpp | cuda | spirv | wgsl;
/// All non-asm code-gen targets
/// [Compound]
-alias any_textual_target = hlsl | metal | glsl | c | cpp | cuda;
+alias any_textual_target = hlsl | metal | glsl | c | cpp | cuda | wgsl;
/// All slang-gfx compatible code-gen targets
/// [Compound]
-alias any_gfx_target = hlsl | metal | glsl | spirv;
+alias any_gfx_target = hlsl | metal | glsl | spirv | wgsl;
/// All "cpp syntax" code-gen targets
/// [Compound]
@@ -266,6 +270,10 @@ alias cpp_cuda_glsl_hlsl_spirv = cpp | cuda | glsl | hlsl | spirv;
/// [Compound]
alias cpp_cuda_glsl_hlsl_metal_spirv = cpp | cuda | glsl | hlsl | metal | spirv;
+/// CPP, CUDA, GLSL, HLSL, Metal, SPIRV and WGSL code-gen targets
+/// [Compound]
+alias cpp_cuda_glsl_hlsl_metal_spirv_wgsl = cpp | cuda | glsl | hlsl | metal | spirv | wgsl;
+
/// CPP, CUDA, and HLSL code-gen targets
/// [Compound]
alias cpp_cuda_hlsl = cpp | cuda | hlsl;
@@ -1178,6 +1186,7 @@ alias sm_4_0_version = _sm_4_0
| spirv_1_0
| _cuda_sm_2_0
| metal
+ | wgsl
| cpp
;
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp
index 4bb420fa7..541085b4e 100644
--- a/source/slang/slang-compiler.cpp
+++ b/source/slang/slang-compiler.cpp
@@ -1715,6 +1715,7 @@ namespace Slang
case CodeGenTarget::PyTorchCppBinding:
case CodeGenTarget::CSource:
case CodeGenTarget::Metal:
+ case CodeGenTarget::WGSL:
{
RefPtr<ExtensionTracker> extensionTracker = _newExtensionTracker(target);
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index b8ee4dc9c..62e4c5f4a 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -94,6 +94,7 @@ namespace Slang
Metal = SLANG_METAL,
MetalLib = SLANG_METAL_LIB,
MetalLibAssembly = SLANG_METAL_LIB_ASM,
+ WGSL = SLANG_WGSL,
CountOf = SLANG_TARGET_COUNT_OF,
};
diff --git a/source/slang/slang-doc-markdown-writer.cpp b/source/slang/slang-doc-markdown-writer.cpp
index c9dd6d9c8..e32a738c7 100644
--- a/source/slang/slang-doc-markdown-writer.cpp
+++ b/source/slang/slang-doc-markdown-writer.cpp
@@ -483,6 +483,10 @@ static DocMarkdownWriter::Requirement _getRequirementFromTargetToken(const Token
{
return Requirement{ CodeGenTarget::Metal, targetName };
}
+ else if (isCapabilityDerivedFrom(targetCap, CapabilityAtom::wgsl))
+ {
+ return Requirement{ CodeGenTarget::WGSL, targetName };
+ }
return Requirement{ CodeGenTarget::Unknown, String() };
}
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index 1893929f8..caf3613a7 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -95,6 +95,10 @@ struct CLikeSourceEmitter::ComputeEmitActionsContext
{
return SourceLanguage::Metal;
}
+ case CodeGenTarget::WGSL:
+ {
+ return SourceLanguage::WGSL;
+ }
}
}
@@ -151,7 +155,7 @@ void CLikeSourceEmitter::ensureTypePrelude(IRType* type)
}
}
-void CLikeSourceEmitter::emitDeclarator(DeclaratorInfo* declarator)
+void CLikeSourceEmitter::emitDeclaratorImpl(DeclaratorInfo* declarator)
{
if (!declarator) return;
@@ -341,13 +345,18 @@ void CLikeSourceEmitter::_emitPostfixTypeAttr(IRAttr* attr)
// we may need to handle it here.
}
+void CLikeSourceEmitter::emitSimpleTypeAndDeclaratorImpl(IRType* type, DeclaratorInfo* declarator)
+{
+ emitSimpleType(type);
+ emitDeclarator(declarator);
+}
+
void CLikeSourceEmitter::_emitType(IRType* type, DeclaratorInfo* declarator)
{
switch (type->getOp())
{
default:
- emitSimpleType(type);
- emitDeclarator(declarator);
+ emitSimpleTypeAndDeclarator(type, declarator);
break;
case kIROp_RateQualifiedType:
@@ -648,7 +657,7 @@ bool CLikeSourceEmitter::maybeEmitParens(EmitOpInfo& outerPrec, const EmitOpInfo
bool needParens = (prec.leftPrecedence <= outerPrec.leftPrecedence)
|| (prec.rightPrecedence <= outerPrec.rightPrecedence);
- // While Slang correctly removes some of parentheses, DXC prints warnings
+ // While Slang correctly removes some of parentheses, many compilers print warnings
// for common mistakes when parentheses are not used with certain combinations
// of the operations. We emit parentheses to avoid the warnings.
//
@@ -676,6 +685,12 @@ bool CLikeSourceEmitter::maybeEmitParens(EmitOpInfo& outerPrec, const EmitOpInfo
{
needParens = true;
}
+ // a + b & c => (a + b) & c
+ else if (prec.rightPrecedence == EPrecedence::kEPrecedence_Additive_Right
+ && outerPrec.rightPrecedence == EPrecedence::kEPrecedence_BitAnd_Left)
+ {
+ needParens = true;
+ }
if (needParens)
{
@@ -1657,11 +1672,16 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst)
return true;
}
+bool CLikeSourceEmitter::isPointerSyntaxRequiredImpl(IRInst* /* inst */)
+{
+ return doesTargetSupportPtrTypes();
+}
+
void CLikeSourceEmitter::emitDereferenceOperand(IRInst* inst, EmitOpInfo const& outerPrec)
{
EmitOpInfo newOuterPrec = outerPrec;
- if (doesTargetSupportPtrTypes())
+ if (isPointerSyntaxRequiredImpl(inst))
{
switch (inst->getOp())
{
@@ -1760,7 +1780,7 @@ void CLikeSourceEmitter::emitDereferenceOperand(IRInst* inst, EmitOpInfo const&
void CLikeSourceEmitter::emitVarExpr(IRInst* inst, EmitOpInfo const& outerPrec)
{
- if (doesTargetSupportPtrTypes())
+ if (isPointerSyntaxRequiredImpl(inst))
{
auto prec = getInfo(EmitOp::Prefix);
auto newOuterPrec = outerPrec;
@@ -1842,7 +1862,8 @@ void CLikeSourceEmitter::emitInstResultDecl(IRInst* inst)
emitRateQualifiers(inst);
- if(as<IRModuleInst>(inst->getParent()))
+ bool isConstant(as<IRModuleInst>(inst->getParent()));
+ if(isConstant)
{
// "Ordinary" instructions at module scope are constants
@@ -1857,6 +1878,9 @@ void CLikeSourceEmitter::emitInstResultDecl(IRInst* inst)
case SourceLanguage::Metal:
m_writer->emit("constant ");
break;
+ case SourceLanguage::WGSL:
+ // This is handled by emitVarKeyword, below
+ break;
default:
m_writer->emit("const ");
break;
@@ -1864,6 +1888,8 @@ void CLikeSourceEmitter::emitInstResultDecl(IRInst* inst)
}
+ emitVarKeyword(type, isConstant);
+
emitType(type, getName(inst));
m_writer->emit(" = ");
}
@@ -2297,7 +2323,7 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO
IRFieldAddress* ii = (IRFieldAddress*) inst;
- if (doesTargetSupportPtrTypes())
+ if (isPointerSyntaxRequiredImpl(inst))
{
auto prec = getInfo(EmitOp::Prefix);
needClose = maybeEmitParens(outerPrec, prec);
@@ -3117,6 +3143,8 @@ void CLikeSourceEmitter::_emitStoreImpl(IRStore* store)
void CLikeSourceEmitter::_emitInstAsDefaultInitializedVar(IRInst* inst, IRType* type)
{
+ emitVarKeyword(type, /* isConstant */ false);
+
emitType(type, getName(inst));
// On targets that support empty initializers, we will emit it.
@@ -3178,6 +3206,20 @@ void CLikeSourceEmitter::emitLayoutSemantics(IRInst* inst, char const* uniformSe
emitLayoutSemanticsImpl(inst, uniformSemanticSpelling, EmitLayoutSemanticOption::kPostType);
}
+void CLikeSourceEmitter::emitSwitchCaseSelectorsImpl(IRBasicType *const /* switchCondition */, const SwitchRegion::Case *const currentCase, const bool isDefault)
+{
+ for(auto caseVal : currentCase->values)
+ {
+ m_writer->emit("case ");
+ emitOperand(caseVal, getInfo(EmitOp::General));
+ m_writer->emit(":\n");
+ }
+ if(isDefault)
+ {
+ m_writer->emit("default:\n");
+ }
+}
+
void CLikeSourceEmitter::emitRegion(Region* inRegion)
{
// We will use a loop so that we can process sequential (simple)
@@ -3333,17 +3375,9 @@ void CLikeSourceEmitter::emitRegion(Region* inRegion)
auto defaultCase = switchRegion->defaultCase;
for(auto currentCase : switchRegion->cases)
{
- for(auto caseVal : currentCase->values)
- {
- m_writer->emit("case ");
- emitOperand(caseVal, getInfo(EmitOp::General));
- m_writer->emit(":\n");
- }
- if(currentCase.Ptr() == defaultCase)
- {
- m_writer->emit("default:\n");
- }
-
+ const bool isDefault {currentCase.Ptr() == defaultCase};
+ IRBasicType *const switchConditionType {as<IRBasicType>(switchRegion->getCondition()->getDataType())};
+ emitSwitchCaseSelectors(switchConditionType, currentCase.Ptr(), isDefault);
m_writer->indent();
m_writer->emit("{\n");
m_writer->indent();
@@ -3449,9 +3483,16 @@ void CLikeSourceEmitter::emitSimpleFuncParamsImpl(IRFunc* func)
m_writer->emit(")");
}
-void CLikeSourceEmitter::emitSimpleFuncImpl(IRFunc* func)
+void CLikeSourceEmitter::emitFuncHeaderImpl(IRFunc* func)
{
auto resultType = func->getResultType();
+ auto name = getName(func);
+ emitType(resultType, name);
+ emitSimpleFuncParamsImpl(func);
+}
+
+void CLikeSourceEmitter::emitSimpleFuncImpl(IRFunc* func)
+{
// Deal with decorations that need
// to be emitted as attributes
@@ -3467,12 +3508,8 @@ void CLikeSourceEmitter::emitSimpleFuncImpl(IRFunc* func)
emitFunctionPreambleImpl(func);
- auto name = getName(func);
-
emitFuncDecorations(func);
-
- emitType(resultType, name);
- emitSimpleFuncParamsImpl(func);
+ emitFuncHeader(func);
emitSemantics(func);
// TODO: encode declaration vs. definition
@@ -3688,6 +3725,11 @@ void CLikeSourceEmitter::emitStruct(IRStructType* structType)
m_writer->emit(";\n\n");
}
+void CLikeSourceEmitter::emitStructDeclarationSeparatorImpl()
+{
+ m_writer->emit(";");
+}
+
void CLikeSourceEmitter::emitStructDeclarationsBlock(IRStructType* structType, bool allowOffsetLayout)
{
m_writer->emit("\n{\n");
@@ -3716,11 +3758,13 @@ void CLikeSourceEmitter::emitStructDeclarationsBlock(IRStructType* structType, b
emitPackOffsetModifier(fieldKey, fieldType, packOffsetDecoration);
}
}
+ emitStructFieldAttributes(structType, ff);
emitMemoryQualifiers(fieldKey);
emitType(fieldType, getName(fieldKey));
emitSemantics(fieldKey, allowOffsetLayout);
emitPostDeclarationAttributesForType(fieldType);
- m_writer->emit(";\n");
+ emitStructDeclarationSeparator();
+ m_writer->emit("\n");
}
m_writer->dedent();
@@ -3931,6 +3975,8 @@ void CLikeSourceEmitter::emitParameterGroup(IRGlobalParam* varDecl, IRUniformPar
emitParameterGroupImpl(varDecl, type);
}
+void CLikeSourceEmitter::emitVarKeywordImpl(IRType * /* type */, bool /* isConstant */) {}
+
void CLikeSourceEmitter::emitVar(IRVar* varDecl)
{
auto allocatedType = varDecl->getDataType();
@@ -3969,6 +4015,8 @@ void CLikeSourceEmitter::emitVar(IRVar* varDecl)
#endif
emitRateQualifiersAndAddressSpace(varDecl);
+ emitVarKeyword(varType, /* isConstant */ false);
+
emitType(varType, getName(varDecl));
emitSemantics(varDecl);
@@ -4099,6 +4147,7 @@ void CLikeSourceEmitter::emitGlobalVar(IRGlobalVar* varDecl)
emitVarModifiers(layout, varDecl, varType);
emitRateQualifiersAndAddressSpace(varDecl);
+ emitVarKeyword(varType, /* isConstant */ true);
emitType(varType, getName(varDecl));
// TODO: These shouldn't be needed for ordinary
@@ -4172,7 +4221,8 @@ void CLikeSourceEmitter::emitGlobalParam(IRGlobalParam* varDecl)
emitDecorationLayoutSemantics(varDecl, "register");
emitRateQualifiersAndAddressSpace(varDecl);
- emitType(varType, getName(varDecl));
+ emitVarKeyword(varType, /* isConstant */ false);
+ emitGlobalParamType(varType, getName(varDecl));
emitSemantics(varDecl);
diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h
index 00ad156d1..be769f31f 100644
--- a/source/slang/slang-emit-c-like.h
+++ b/source/slang/slang-emit-c-like.h
@@ -248,7 +248,8 @@ public:
//
void ensureTypePrelude(IRType* type);
- void emitDeclarator(DeclaratorInfo* declarator);
+ void emitDeclarator(DeclaratorInfo* declarator) {emitDeclaratorImpl(declarator);}
+ virtual void emitDeclaratorImpl(DeclaratorInfo* declarator);
void emitType(IRType* type, const StringSliceLoc* nameLoc) { emitTypeImpl(type, nameLoc); }
void emitType(IRType* type, Name* name);
@@ -256,6 +257,7 @@ public:
void emitType(IRType* type);
void emitType(IRType* type, Name* name, SourceLoc const& nameLoc);
void emitType(IRType* type, NameLoc const& nameAndLoc);
+ virtual void emitGlobalParamType(IRType* type, String const& name) {emitType(type, name);}
bool hasExplicitConstantBufferOffset(IRInst* cbufferType);
bool isSingleElementConstantBuffer(IRInst* cbufferType);
bool shouldForceUnpackConstantBufferElements(IRInst* cbufferType);
@@ -368,8 +370,11 @@ public:
/// Emit high-level statements for the body of a function.
void emitFunctionBody(IRGlobalValueWithCode* code);
+ void emitFuncHeader(IRFunc* func) { emitFuncHeaderImpl(func); }
void emitSimpleFunc(IRFunc* func) { emitSimpleFuncImpl(func); }
+ void emitSwitchCaseSelectors(IRBasicType *const switchConditionType, const SwitchRegion::Case *const currentCase, const bool isDefault) {emitSwitchCaseSelectorsImpl(switchConditionType, currentCase, isDefault);}
+
void emitParamType(IRType* type, String const& name) { emitParamTypeImpl(type, name); }
void emitFuncDecl(IRFunc* func);
@@ -394,10 +399,14 @@ public:
void emitStructDeclarationsBlock(IRStructType* structType, bool allowOffsetLayout);
void emitClass(IRClassType* structType);
+ void emitStructDeclarationSeparator() {emitStructDeclarationSeparatorImpl();}
+ virtual void emitStructDeclarationSeparatorImpl();
+
/// Emit type attributes that should appear after, e.g., a `struct` keyword
void emitPostKeywordTypeAttributes(IRInst* inst) { emitPostKeywordTypeAttributesImpl(inst); }
virtual void emitMemoryQualifiers(IRInst* /*varInst*/) {};
+ virtual void emitStructFieldAttributes(IRStructType * /* structType */, IRStructField * /* field */) {};
void emitInterpolationModifiers(IRInst* varInst, IRType* valueType, IRVarLayout* layout);
void emitMeshShaderModifiers(IRInst* varInst);
virtual void emitPackOffsetModifier(IRInst* /*varInst*/, IRType* /*valueType*/, IRPackOffsetDecoration* /*decoration*/) {};
@@ -421,6 +430,7 @@ public:
void emitGlobalInst(IRInst* inst);
virtual void emitGlobalInstImpl(IRInst* inst);
+ virtual bool isPointerSyntaxRequiredImpl(IRInst* inst);
void ensureInstOperand(ComputeEmitActionsContext* ctx, IRInst* inst, EmitAction::Level requiredLevel = EmitAction::Level::Definition);
@@ -486,6 +496,11 @@ public:
virtual void emitPreModuleImpl();
virtual void emitPostModuleImpl();
+ virtual void emitSimpleTypeAndDeclaratorImpl(IRType* type, DeclaratorInfo* declarator);
+ void emitSimpleTypeAndDeclarator(IRType* type, DeclaratorInfo* declarator) {emitSimpleTypeAndDeclaratorImpl(type, declarator);};
+ virtual void emitVarKeywordImpl(IRType * type, bool isConstant);
+ void emitVarKeyword(IRType * type, bool isConstant) {emitVarKeywordImpl(type, isConstant);}
+
virtual void beforeComputeEmitActions(IRModule* module) { SLANG_UNUSED(module); };
virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, AddressSpace addressSpace) { SLANG_UNUSED(rate); SLANG_UNUSED(addressSpace); }
@@ -501,6 +516,7 @@ public:
virtual void emitTypeImpl(IRType* type, const StringSliceLoc* nameLoc);
virtual void emitSimpleValueImpl(IRInst* inst);
virtual void emitModuleImpl(IRModule* module, DiagnosticSink* sink);
+ virtual void emitFuncHeaderImpl(IRFunc* func);
virtual void emitSimpleFuncImpl(IRFunc* func);
virtual void emitVarExpr(IRInst* inst, EmitOpInfo const& outerPrec);
virtual void emitOperandImpl(IRInst* inst, EmitOpInfo const& outerPrec);
@@ -511,6 +527,7 @@ public:
virtual void emitLoopControlDecorationImpl(IRLoopControlDecoration* decl) { SLANG_UNUSED(decl); }
virtual void emitIfDecorationsImpl(IRIfElse* ifInst) { SLANG_UNUSED(ifInst); }
virtual void emitSwitchDecorationsImpl(IRSwitch* switchInst) { SLANG_UNUSED(switchInst); }
+ virtual void emitSwitchCaseSelectorsImpl(IRBasicType *const switchConditionType, const SwitchRegion::Case *const currentCase, const bool isDefault);
virtual void emitFuncDecorationImpl(IRDecoration* decoration) { SLANG_UNUSED(decoration); }
virtual void emitLivenessImpl(IRInst* inst);
diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp
new file mode 100644
index 000000000..0a4cca407
--- /dev/null
+++ b/source/slang/slang-emit-wgsl.cpp
@@ -0,0 +1,1005 @@
+#include "slang-emit-wgsl.h"
+
+// A note on row/column "terminology reversal".
+//
+// This is an "terminology reversing" implementation in the sense that
+// * "column" in Slang code maps to "row" in the generated WGSL code, and
+// * "row" in Slang code maps to "column" in the generated WGSL code.
+//
+// This means that matrices in Slang code end up getting translated to
+// matrices that actually represent the transpose of what the Slang matrix
+// represented.
+// Both API's adopt the standard matrix multiplication convention whereby the
+// column count of the matrix on the left hand side needs to match row count of
+// the matrix on the right hand side.
+// For these reasons, and due to the fact that (M_1 ... M_n)^T = M_n^T ... M_1^T,
+// the order of matrix (and vector-matrix products) products must also reversed
+// in the WGSL code.
+//
+// This may lead to confusion (which is why this note is referenced in several
+// places), but the benefit of doing this is that the generated WGSL code is
+// simpler to generate and should be faster to compile.
+// A "terminology preserving" implementation would have to generate lots of
+// 'transpose' calls, or else perform more complicated transformations that
+// end up duplicating expressions many times.
+
+namespace Slang {
+
+void WGSLSourceEmitter::emitSwitchCaseSelectorsImpl(
+ IRBasicType *const switchConditionType,
+ const SwitchRegion::Case *const currentCase, const bool isDefault
+ )
+{
+ // WGSL has special syntax for blocks sharing case labels:
+ // "case 2, 3, 4: ...;" instead of the C-like syntax
+ // "case 2: case 3: case 4: ...;".
+
+ m_writer->emit("case ");
+ for (auto caseVal : currentCase->values)
+ {
+ // TODO: Fix this in the front-end [1], remove the if-path and just do the else-path.
+ // We can't do that at the moment because it would break Falcor [2].
+ // [1] https://github.com/shader-slang/slang/pull/5025/commits/a32156ef52f43b8503b2c77f2f1d51220ab9bdea
+ // [2] https://github.com/shader-slang/slang/pull/5025#issuecomment-2334495120
+ if (caseVal->getOp() == kIROp_IntLit)
+ {
+ auto caseLitInst = static_cast<IRConstant*>(caseVal);
+ IRBasicType *const caseInstType = as<IRBasicType>(caseLitInst->getDataType());
+ // WGSL doesn't allow switch condition and case type mismatches, see [1].
+ // Thus we need to insert explicit conversions.
+ // Doing a wrapping cast will match Slang's de facto semantics, according to
+ // [2].
+ // (This is just a bitcast, assuming a two's complement representation.)
+ // [1] https://www.w3.org/TR/WGSL/#switch-statement
+ // [2] https://github.com/shader-slang/slang/issues/4921
+ const bool needBitcast =
+ caseInstType->getBaseType() != switchConditionType->getBaseType();
+ if (needBitcast)
+ {
+ m_writer->emit("bitcast<");
+ emitType(switchConditionType);
+ m_writer->emit(">(");
+ }
+ emitOperand(caseVal, getInfo(EmitOp::General));
+ if (needBitcast)
+ {
+ m_writer->emit(")");
+ }
+ }
+ else
+ {
+ emitOperand(caseVal, getInfo(EmitOp::General));
+ }
+ m_writer->emit(", ");
+ }
+ if (isDefault)
+ {
+ m_writer->emit("default, ");
+ }
+ m_writer->emit(":\n");
+}
+
+void WGSLSourceEmitter::emitParameterGroupImpl(
+ IRGlobalParam* varDecl, IRUniformParameterGroupType* type
+)
+{
+ auto varLayout = getVarLayout(varDecl);
+ SLANG_RELEASE_ASSERT(varLayout);
+
+ for (auto attr : varLayout->getOffsetAttrs())
+ {
+
+ const LayoutResourceKind kind = attr->getResourceKind();
+ switch (kind)
+ {
+ case LayoutResourceKind::VaryingInput:
+ case LayoutResourceKind::VaryingOutput:
+ m_writer->emit("@location(");
+ m_writer->emit(attr->getOffset());
+ m_writer->emit(")");
+ if (attr->getSpace())
+ {
+ // TODO: Not sure what 'space' should map to in WGSL
+ SLANG_ASSERT(false);
+ }
+ break;
+
+ case LayoutResourceKind::SpecializationConstant:
+ // TODO:
+ // Consider moving to a differently named function.
+ // This is not technically an attribute, but a declaration.
+ //
+ // https://www.w3.org/TR/WGSL/#override-decls
+ m_writer->emit("override");
+ break;
+
+ case LayoutResourceKind::Uniform:
+ case LayoutResourceKind::ConstantBuffer:
+ case LayoutResourceKind::ShaderResource:
+ case LayoutResourceKind::UnorderedAccess:
+ case LayoutResourceKind::SamplerState:
+ case LayoutResourceKind::DescriptorTableSlot:
+ m_writer->emit("@binding(");
+ m_writer->emit(attr->getOffset());
+ m_writer->emit(") ");
+ m_writer->emit("@group(");
+ m_writer->emit(attr->getSpace());
+ m_writer->emit(") ");
+ break;
+
+ }
+
+ }
+
+ auto elementType = type->getElementType();
+ m_writer->emit("var<uniform> ");
+ m_writer->emit(getName(varDecl));
+ m_writer->emit(" : ");
+ emitType(elementType);
+ m_writer->emit(";\n");
+}
+
+void WGSLSourceEmitter::emitEntryPointAttributesImpl(
+ IRFunc* irFunc, IREntryPointDecoration* entryPointDecor
+ )
+{
+ auto stage = entryPointDecor->getProfile().getStage();
+
+ switch (stage)
+ {
+
+ case Stage::Fragment:
+ m_writer->emit("@fragment\n");
+ break;
+ case Stage::Vertex:
+ m_writer->emit("@vertex\n");
+ break;
+
+ case Stage::Compute:
+ {
+ m_writer->emit("@compute\n");
+
+ {
+ Int sizeAlongAxis[kThreadGroupAxisCount];
+ getComputeThreadGroupSize(irFunc, sizeAlongAxis);
+
+ m_writer->emit("@workgroup_size(");
+ for (int ii = 0; ii < kThreadGroupAxisCount; ++ii)
+ {
+ if (ii != 0)
+ m_writer->emit(", ");
+ m_writer->emit(sizeAlongAxis[ii]);
+ }
+ m_writer->emit(")\n");
+ }
+ }
+ break;
+
+ default:
+ SLANG_ABORT_COMPILATION("unsupported stage.");
+ }
+
+}
+
+// This is 'function_header' from the WGSL specification
+void WGSLSourceEmitter::emitFuncHeaderImpl(IRFunc* func)
+{
+ Slang::IRType * resultType = func->getResultType();
+ auto name = getName(func);
+
+ m_writer->emit("fn ");
+ m_writer->emit(name);
+
+ emitSimpleFuncParamsImpl(func);
+
+ // An absence of return type is expressed by skipping the optional '->' part of the
+ // header.
+ if (resultType->getOp() != kIROp_VoidType)
+ {
+ m_writer->emit(" -> ");
+ emitType(resultType);
+ }
+}
+
+void WGSLSourceEmitter::emitSimpleFuncParamImpl(IRParam* param)
+{
+ if (auto sysSemanticDecor = param->findDecoration<IRTargetSystemValueDecoration>())
+ {
+ m_writer->emit("@builtin(");
+ m_writer->emit(sysSemanticDecor->getSemantic());
+ m_writer->emit(")");
+ }
+
+ CLikeSourceEmitter::emitSimpleFuncParamImpl(param);
+}
+
+void WGSLSourceEmitter::emitMatrixType(
+ IRType *const elementType, const IRIntegerValue& rowCountWGSL,
+ const IRIntegerValue& colCountWGSL
+ )
+{
+ // WGSL uses CxR convention
+ m_writer->emit("mat");
+ m_writer->emit(colCountWGSL);
+ m_writer->emit("x");
+ m_writer->emit(rowCountWGSL);
+ m_writer->emit("<");
+ emitType(elementType);
+ m_writer->emit(">");
+}
+
+void WGSLSourceEmitter::emitStructDeclarationSeparatorImpl()
+{
+ m_writer->emit(",");
+}
+
+static bool isPowerOf2(const uint32_t n)
+{
+ return (n != 0U) && ((n - 1U) & n) == 0U;
+}
+
+void WGSLSourceEmitter::emitStructFieldAttributes(
+ IRStructType * structType, IRStructField * field
+ )
+{
+ // Tint emits errors unless we explicitly spell out the layout in some cases, so emit
+ // offset and align attribtues for all fields.
+ IRSizeAndAlignmentDecoration *const sizeAndAlignmentDecoration =
+ structType->findDecoration<IRSizeAndAlignmentDecoration>();
+ // NullDifferential struct doesn't have size and alignment decoration
+ if (sizeAndAlignmentDecoration == nullptr)
+ return;
+ SLANG_ASSERT(sizeAndAlignmentDecoration->getAlignment() > IRIntegerValue{0});
+ SLANG_ASSERT(
+ sizeAndAlignmentDecoration->getAlignment() <= IRIntegerValue{UINT32_MAX}
+ );
+ const uint32_t structAlignment =
+ static_cast<uint32_t>(sizeAndAlignmentDecoration->getAlignment());
+ IROffsetDecoration *const fieldOffsetDecoration =
+ field->findDecoration<IROffsetDecoration>();
+ SLANG_ASSERT(fieldOffsetDecoration->getOffset() >= IRIntegerValue{0});
+ SLANG_ASSERT(fieldOffsetDecoration->getOffset() <= IRIntegerValue{UINT32_MAX});
+ SLANG_ASSERT(isPowerOf2(structAlignment));
+ const uint32_t fieldOffset =
+ static_cast<uint32_t>(fieldOffsetDecoration->getOffset());
+ // Alignment is GCD(fieldOffset, structAlignment)
+ // TODO: Use builtin/intrinsic (e.g. __builtin_ffs)
+ uint32_t fieldAlignment = 1U;
+ while (((fieldAlignment & (structAlignment | fieldOffset)) == 0U))
+ fieldAlignment = fieldAlignment << 1U;
+
+ m_writer->emit("@align(");
+ m_writer->emit(fieldAlignment);
+ m_writer->emit(")");
+}
+
+bool WGSLSourceEmitter::isPointerSyntaxRequiredImpl(IRInst* inst)
+{
+ // Structured buffers are mapped to 'array' types, which don't need dereferencing
+ if (inst->getOp() == kIROp_RWStructuredBufferGetElementPtr)
+ return false;
+
+ // Don't emit "->" to access fields in resource structs
+ if (inst->getOp() == kIROp_FieldAddress)
+ return false;
+
+ // Don't emit "*" to access fields in resource structs
+ if (inst->getOp() == kIROp_GlobalParam)
+ return false;
+
+ // Emit 'globalVar' instead of "*&globalVar"
+ if (inst->getOp() == kIROp_GlobalVar)
+ return false;
+
+ return true;
+}
+
+void WGSLSourceEmitter::emit(const AddressSpace addressSpace)
+{
+ switch (addressSpace)
+ {
+ case AddressSpace::Uniform:
+ m_writer->emit("uniform");
+ break;
+
+ case AddressSpace::StorageBuffer:
+ m_writer->emit("storage");
+ break;
+
+ case AddressSpace::Generic:
+ m_writer->emit("function");
+ break;
+
+ case AddressSpace::ThreadLocal:
+ m_writer->emit("private");
+ break;
+
+ case AddressSpace::GroupShared:
+ m_writer->emit("workgroup");
+ break;
+ }
+}
+
+void WGSLSourceEmitter::emitSimpleTypeImpl(IRType* type)
+{
+ switch (type->getOp())
+ {
+
+ case kIROp_HLSLRWStructuredBufferType:
+ {
+ auto structuredBufferType = as<IRHLSLStructuredBufferTypeBase>(type);
+ m_writer->emit("ptr<");
+ emit(AddressSpace::StorageBuffer);
+ m_writer->emit(", ");
+ m_writer->emit("array");
+ m_writer->emit("<");
+ emitType(structuredBufferType->getElementType());
+ m_writer->emit(">");
+ m_writer->emit(", read_write");
+ m_writer->emit(">");
+ }
+ break;
+
+ case kIROp_HLSLStructuredBufferType:
+ {
+ auto structuredBufferType = as<IRHLSLStructuredBufferTypeBase>(type);
+ m_writer->emit("ptr<");
+ emit(AddressSpace::StorageBuffer);
+ m_writer->emit(", ");
+ m_writer->emit("array");
+ m_writer->emit("<");
+ emitType(structuredBufferType->getElementType());
+ m_writer->emit(">");
+ m_writer->emit(", read");
+ m_writer->emit(">");
+ }
+ break;
+
+ case kIROp_VoidType:
+ {
+ // There is no void type in WGSL.
+ // A return type of "void" is expressed by skipping the end part of the
+ // 'function_header' term:
+ // "
+ // function_header :
+ // 'fn' ident '(' param_list ? ')'
+ // ( '->' attribute * template_elaborated_ident ) ?
+ // "
+ // In other words, in WGSL we should never even get to the point where we're
+ // asking to emit 'void'.
+ SLANG_UNEXPECTED("'void' type emitted");
+ return;
+ }
+
+ case kIROp_FloatType:
+ m_writer->emit("f32");
+ break;
+ case kIROp_DoubleType:
+ // There is no "f64" type in WGSL
+ SLANG_UNEXPECTED("'double' type emitted");
+ break;
+ case kIROp_Int8Type:
+ case kIROp_UInt8Type:
+ // There is no "[i|u]8" type in WGSL
+ SLANG_UNEXPECTED("8 bit integer type emitted");
+ break;
+ case kIROp_HalfType:
+ m_f16ExtensionEnabled = true;
+ m_writer->emit("f16");
+ break;
+ case kIROp_BoolType:
+ m_writer->emit("bool");
+ break;
+ case kIROp_IntType:
+ m_writer->emit("i32");
+ break;
+ case kIROp_UIntType:
+ m_writer->emit("u32");
+ break;
+ case kIROp_UInt64Type:
+ {
+ m_writer->emit(getDefaultBuiltinTypeName(type->getOp()));
+ return;
+ }
+ case kIROp_Int16Type:
+ case kIROp_UInt16Type:
+ SLANG_UNEXPECTED("16 bit integer value emitted");
+ return;
+ case kIROp_Int64Type:
+ case kIROp_IntPtrType:
+ m_writer->emit("i64");
+ return;
+ case kIROp_UIntPtrType:
+ m_writer->emit("u64");
+ return;
+ case kIROp_StructType:
+ m_writer->emit(getName(type));
+ return;
+
+ case kIROp_VectorType:
+ {
+ auto vecType = (IRVectorType*)type;
+ emitVectorTypeNameImpl(
+ vecType->getElementType(), getIntVal(vecType->getElementCount())
+ );
+ return;
+ }
+ case kIROp_MatrixType:
+ {
+ auto matType = (IRMatrixType*)type;
+ // We map matrices in Slang to WGSL matrices that represent the transpose.
+ // (See note on "terminology reversal".)
+ const IRIntegerValue colCountWGSL = getIntVal(matType->getRowCount());
+ const IRIntegerValue rowCountWGSL = getIntVal(matType->getColumnCount());
+ emitMatrixType(matType->getElementType(), rowCountWGSL, colCountWGSL);
+ return;
+ }
+ case kIROp_SamplerStateType:
+ {
+ m_writer->emit("sampler");
+ return;
+ }
+
+ case kIROp_SamplerComparisonStateType:
+ {
+ m_writer->emit("sampler_comparison");
+ return;
+ }
+
+ case kIROp_PtrType:
+ case kIROp_InOutType:
+ case kIROp_OutType:
+ case kIROp_RefType:
+ case kIROp_ConstRefType:
+ {
+ auto ptrType = cast<IRPtrTypeBase>(type);
+ m_writer->emit("ptr<");
+ emit((AddressSpace)ptrType->getAddressSpace());
+ m_writer->emit(", ");
+ emitType((IRType*)ptrType->getValueType());
+ m_writer->emit(">");
+ return;
+ }
+
+ case kIROp_ArrayType:
+ {
+ m_writer->emit("array<");
+ emitType((IRType*)type->getOperand(0));
+ m_writer->emit(", ");
+ emitVal(type->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(">");
+ return;
+ }
+ default:
+ break;
+
+ }
+
+}
+
+void WGSLSourceEmitter::emitLayoutQualifiersImpl(IRVarLayout* layout)
+{
+
+ for (auto attr : layout->getOffsetAttrs())
+ {
+ LayoutResourceKind kind = attr->getResourceKind();
+
+ // TODO:
+ // This is not correct. For the moment this is just here as a hack to make
+ // @binding and @group unique, so that we can pass WGSL compile tests.
+ // This will have to be revisited when we actually want to supply resources to
+ // shaders.
+ if (kind == LayoutResourceKind::DescriptorTableSlot)
+ {
+ m_writer->emit("@binding(");
+ m_writer->emit(attr->getOffset());
+ m_writer->emit(") ");
+ m_writer->emit("@group(");
+ m_writer->emit(attr->getSpace());
+ m_writer->emit(") ");
+
+ return;
+ }
+ }
+
+}
+
+void WGSLSourceEmitter::emitVarKeywordImpl(IRType * type, const bool isConstant)
+{
+ if (isConstant)
+ m_writer->emit("const");
+ else
+ m_writer->emit("var");
+ if (type->getOp() == kIROp_HLSLRWStructuredBufferType)
+ {
+ m_writer->emit("<");
+ m_writer->emit("storage, read_write");
+ m_writer->emit(">");
+ }
+ else if (type->getOp() == kIROp_HLSLStructuredBufferType)
+ {
+ m_writer->emit("<");
+ m_writer->emit("storage, read");
+ m_writer->emit(">");
+ }
+}
+
+void WGSLSourceEmitter::_emitType(IRType* type, DeclaratorInfo* declarator)
+{
+ // C-like languages bake array-ness, pointer-ness and reference-ness into the
+ // declarator, which happens in the default _emitType implementation.
+ // WGSL on the other hand, don't have special syntax -- these are just types.
+ switch (type->getOp())
+ {
+ case kIROp_ArrayType:
+ case kIROp_AttributedType:
+ case kIROp_UnsizedArrayType:
+ emitSimpleTypeAndDeclarator(type, declarator);
+ break;
+ default:
+ CLikeSourceEmitter::_emitType(type, declarator);
+ break;
+ }
+}
+
+void WGSLSourceEmitter::emitDeclaratorImpl(DeclaratorInfo* declarator)
+{
+ if (!declarator) return;
+
+ m_writer->emit(" ");
+
+ switch (declarator->flavor)
+ {
+ case DeclaratorInfo::Flavor::Name:
+ {
+ auto nameDeclarator = (NameDeclaratorInfo*)declarator;
+ m_writer->emitName(*nameDeclarator->nameAndLoc);
+ }
+ break;
+
+ case DeclaratorInfo::Flavor::SizedArray:
+ {
+ // Sized arrays are just types (array<T, N>) in WGSL -- they are not
+ // supported at the syntax level
+ // https://www.w3.org/TR/WGSL/#array
+ SLANG_UNEXPECTED("Sized array declarator");
+ }
+ break;
+
+ case DeclaratorInfo::Flavor::UnsizedArray:
+ {
+ // Unsized arrays are just types (array<T>) in WGSL -- they are not
+ // supported at the syntax level
+ // https://www.w3.org/TR/WGSL/#array
+ SLANG_UNEXPECTED("Unsized array declarator");
+ }
+ break;
+
+ case DeclaratorInfo::Flavor::Ptr:
+ {
+ // Pointers (ptr<AS,T,AM>) are just types in WGSL -- they are not supported at
+ // the syntax level
+ // https://www.w3.org/TR/WGSL/#ref-ptr-types
+ SLANG_UNEXPECTED("Pointer declarator");
+ }
+ break;
+
+ case DeclaratorInfo::Flavor::Ref:
+ {
+ // References (ref<AS,T,AM>) are just types in WGSL -- they are not supported
+ // at the syntax level
+ // https://www.w3.org/TR/WGSL/#ref-ptr-types
+ SLANG_UNEXPECTED("Reference declarator");
+ }
+ break;
+
+ case DeclaratorInfo::Flavor::LiteralSizedArray:
+ {
+ // Sized arrays are just types (array<T, N>) in WGSL -- they are not supported
+ // at the syntax level
+ // https://www.w3.org/TR/WGSL/#array
+ SLANG_UNEXPECTED("Literal-sized array declarator");
+ }
+ break;
+
+ case DeclaratorInfo::Flavor::Attributed:
+ {
+ auto attributedDeclarator = (AttributedDeclaratorInfo*)declarator;
+ auto instWithAttributes = attributedDeclarator->instWithAttributes;
+ for (auto attr : instWithAttributes->getAllAttrs())
+ {
+ _emitPostfixTypeAttr(attr);
+ }
+ emitDeclarator(attributedDeclarator->next);
+ }
+ break;
+
+ default:
+ SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unknown declarator flavor");
+ break;
+ }
+}
+
+void WGSLSourceEmitter::emitSimpleTypeAndDeclaratorImpl(
+ IRType* type, DeclaratorInfo* declarator
+ )
+{
+ if (declarator)
+ {
+ emitDeclarator(declarator);
+ m_writer->emit(" : ");
+ }
+ emitSimpleType(type);
+}
+
+void WGSLSourceEmitter::emitSimpleValueImpl(IRInst* inst)
+{
+ switch (inst->getOp())
+ {
+ case kIROp_IntLit:
+ {
+ auto litInst = static_cast<IRConstant*>(inst);
+
+ IRBasicType* type = as<IRBasicType>(inst->getDataType());
+ if (type)
+ {
+ switch (type->getBaseType())
+ {
+ default:
+
+ case BaseType::Int8:
+ case BaseType::UInt8:
+ {
+ SLANG_UNEXPECTED("8 bit integer value emitted");
+ break;
+ }
+ case BaseType::Int16:
+ case BaseType::UInt16:
+ {
+ SLANG_UNEXPECTED("16 bit integer value emitted");
+ break;
+ }
+ case BaseType::Int:
+ {
+ m_writer->emit("i32(");
+ m_writer->emit(int32_t(litInst->value.intVal));
+ m_writer->emit(")");
+ return;
+ }
+ case BaseType::UInt:
+ {
+ m_writer->emit("u32(");
+ m_writer->emit(UInt(uint32_t(litInst->value.intVal)));
+ m_writer->emit(")");
+ break;
+ }
+ case BaseType::Int64:
+ {
+ m_writer->emit("i64(");
+ m_writer->emitInt64(int64_t(litInst->value.intVal));
+ m_writer->emit(")");
+ break;
+ }
+ case BaseType::UInt64:
+ {
+ m_writer->emit("u64(");
+ SLANG_COMPILE_TIME_ASSERT(
+ sizeof(litInst->value.intVal) >= sizeof(uint64_t)
+ );
+ m_writer->emitUInt64(uint64_t(litInst->value.intVal));
+ m_writer->emit(")");
+ break;
+ }
+ case BaseType::IntPtr:
+ {
+#if SLANG_PTR_IS_64
+ m_writer->emit("i64(");
+ m_writer->emitInt64(int64_t(litInst->value.intVal));
+ m_writer->emit(")");
+#else
+ m_writer->emit("i32(");
+ m_writer->emit(int(litInst->value.intVal));
+ m_writer->emit(")");
+#endif
+ break;
+ }
+ case BaseType::UIntPtr:
+ {
+#if SLANG_PTR_IS_64
+ m_writer->emit("u64(");
+ m_writer->emitUInt64(uint64_t(litInst->value.intVal));
+ m_writer->emit(")");
+#else
+ m_writer->emit("u32(");
+ m_writer->emit(UInt(uint32_t(litInst->value.intVal)));
+ m_writer->emit(")");
+#endif
+ break;
+ }
+
+ }
+ }
+ else
+ {
+ // If no type... just output what we have
+ m_writer->emit(litInst->value.intVal);
+ }
+ break;
+ }
+
+ case kIROp_FloatLit:
+ {
+ auto litInst = static_cast<IRConstant*>(inst);
+
+ IRBasicType* type = as<IRBasicType>(inst->getDataType());
+ if (type)
+ {
+ switch (type->getBaseType())
+ {
+ default:
+
+ case BaseType::Half:
+ {
+ m_writer->emit(litInst->value.floatVal);
+ m_writer->emit("h");
+ m_f16ExtensionEnabled = true;
+ }
+ break;
+
+ case BaseType::Float:
+ {
+ m_writer->emit(litInst->value.floatVal);
+ m_writer->emit("f");
+ }
+ break;
+
+ case BaseType::Double:
+ {
+ // There is not "f64" in WGSL
+ SLANG_UNEXPECTED("'double' type emitted");
+ }
+ break;
+ }
+ }
+ else
+ {
+ // If no type... just output what we have
+ m_writer->emit(litInst->value.floatVal);
+ }
+ }
+ break;
+
+ case kIROp_BoolLit:
+ {
+ bool val = ((IRConstant*)inst)->value.intVal != 0;
+ m_writer->emit(val ? "true" : "false");
+ }
+ break;
+
+ default:
+ SLANG_UNIMPLEMENTED_X("val case for emit");
+ break;
+ }
+
+
+}
+
+void WGSLSourceEmitter::emitParamTypeImpl(IRType* type, const String& name)
+{
+ emitType(type, name);
+}
+
+bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec)
+{
+ EmitOpInfo outerPrec = inOuterPrec;
+
+ switch (inst->getOp())
+ {
+
+ case kIROp_MakeVectorFromScalar:
+ {
+ // In WGSL this is done by calling the vec* overloads listed in [1]
+ // [1] https://www.w3.org/TR/WGSL/#value-constructor-builtin-function
+ emitType(inst->getDataType());
+ m_writer->emit("(");
+ auto prec = getInfo(EmitOp::Prefix);
+ emitOperand(inst->getOperand(0), rightSide(outerPrec, prec));
+ m_writer->emit(")");
+ return true;
+ }
+ break;
+
+ case kIROp_BitCast:
+ {
+ // In WGSL there is a built-in bitcast function!
+ // https://www.w3.org/TR/WGSL/#bitcast-builtin
+ m_writer->emit("bitcast");
+ m_writer->emit("<");
+ emitType(inst->getDataType());
+ m_writer->emit(">");
+ m_writer->emit("(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(")");
+ return true;
+ }
+ break;
+
+ case kIROp_MakeArray:
+ case kIROp_MakeStruct:
+ {
+ // It seems there are currently no designated initializers in WGSL.
+ // Similarly for array initializers.
+ // https://github.com/gpuweb/gpuweb/issues/4210
+
+ // There is a constructor named like the struct/array type itself
+ auto type = inst->getDataType();
+ emitType(type);
+ m_writer->emit("( ");
+ UInt argCount = inst->getOperandCount();
+ for (UInt aa = 0; aa < argCount; ++aa)
+ {
+ if (aa != 0) m_writer->emit(", ");
+ emitOperand(inst->getOperand(aa), getInfo(EmitOp::General));
+ }
+ m_writer->emit(" )");
+
+ return true;
+ }
+ break;
+
+ case kIROp_MakeArrayFromElement:
+ {
+ // It seems there are currently no array initializers in WGSL.
+
+ // There is a constructor named like the array type itself
+ auto type = inst->getDataType();
+ emitType(type);
+ m_writer->emit("(");
+ UInt argCount =
+ (UInt)cast<IRIntLit>(
+ cast<IRArrayType>(inst->getDataType())->getElementCount()
+ )->getValue();
+ for (UInt aa = 0; aa < argCount; ++aa)
+ {
+ if (aa != 0) m_writer->emit(", ");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ }
+ m_writer->emit(")");
+ return true;
+ }
+ break;
+
+ case kIROp_StructuredBufferLoad:
+ case kIROp_RWStructuredBufferLoad:
+ {
+ // Structured buffers are just arrays in WGSL
+ auto base = inst->getOperand(0);
+ emitOperand(base, outerPrec);
+ m_writer->emit("[");
+ emitOperand(inst->getOperand(1), EmitOpInfo());
+ m_writer->emit("]");
+ return true;
+ }
+ break;
+
+ case kIROp_Rsh:
+ case kIROp_Lsh:
+ {
+ // Shift amounts must be an unsigned type in WGSL
+ // https://www.w3.org/TR/WGSL/#bit-expr
+ IRInst *const shiftAmount = inst->getOperand(1);
+ IRType *const shiftAmountType = shiftAmount->getDataType();
+ if (shiftAmountType->getOp() == kIROp_IntType)
+ {
+ // Dawn complains about "mixing '<<' and '|' requires parenthesis", so let's
+ // add parenthesis.
+ m_writer->emit("(");
+
+ const auto emitOp = getEmitOpForOp(inst->getOp());
+ const auto info = getInfo(emitOp);
+
+ const bool needClose = maybeEmitParens(outerPrec, info);
+ emitOperand(inst->getOperand(0), leftSide(outerPrec, info));
+ m_writer->emit(" ");
+ m_writer->emit(info.op);
+ m_writer->emit(" ");
+ m_writer->emit("bitcast<u32>(");
+ emitOperand(inst->getOperand(1), rightSide(outerPrec, info));
+ m_writer->emit(")");
+ maybeCloseParens(needClose);
+
+ m_writer->emit(")");
+ return true;
+ }
+ }
+ break;
+
+ }
+
+ return false;
+}
+
+void WGSLSourceEmitter::emitVectorTypeNameImpl(
+ IRType* elementType, IRIntegerValue elementCount
+ )
+{
+
+ if (elementCount > 1)
+ {
+ m_writer->emit("vec");
+ m_writer->emit(elementCount);
+ m_writer->emit("<");
+ emitSimpleType(elementType);
+ m_writer->emit(">");
+ }
+ else
+ {
+ emitSimpleType(elementType);
+ }
+}
+
+void WGSLSourceEmitter::emitOperandImpl(IRInst* inst, const EmitOpInfo& outerPrec)
+{
+ // In WGSL, the structured buffer types are converted to ptr<AS, array<E>, AM>
+ // everywhere, except for the global parameter declaration.
+ // Thus, when these globals are used in expressions, we need an ampersand.
+
+ if (inst->getOp() == kIROp_GlobalParam)
+ {
+ switch (inst->getDataType()->getOp())
+ {
+ case kIROp_HLSLStructuredBufferType:
+ case kIROp_HLSLRWStructuredBufferType:
+
+ m_writer->emit("(&");
+ CLikeSourceEmitter::emitOperandImpl(inst, outerPrec);
+ m_writer->emit(")");
+ return;
+ }
+ }
+
+ CLikeSourceEmitter::emitOperandImpl(inst, outerPrec);
+}
+
+void WGSLSourceEmitter::emitGlobalParamType(IRType* type, const String& name)
+{
+ // In WGSL, the structured buffer types are converted to ptr<AS, array<E>, AM>
+ // everywhere, except for the global parameter declaration.
+
+ switch (type->getOp())
+ {
+
+ case kIROp_HLSLStructuredBufferType:
+ case kIROp_HLSLRWStructuredBufferType:
+ {
+ StringSliceLoc nameAndLoc(name.getUnownedSlice());
+ NameDeclaratorInfo nameDeclarator(&nameAndLoc);
+ emitDeclarator(&nameDeclarator);
+ m_writer->emit(" : ");
+ auto structuredBufferType = as<IRHLSLStructuredBufferTypeBase>(type);
+ m_writer->emit("array");
+ m_writer->emit("<");
+ emitType(structuredBufferType->getElementType());
+ m_writer->emit(">");
+ }
+ break;
+
+ default:
+
+ emitType(type, name);
+ break;
+
+ }
+
+}
+
+void WGSLSourceEmitter::emitFrontMatterImpl(TargetRequest* /* targetReq */)
+{
+ if (m_f16ExtensionEnabled)
+ {
+ m_writer->emit("enable f16;\n");
+ m_writer->emit("\n");
+ }
+}
+
+} // namespace Slang
diff --git a/source/slang/slang-emit-wgsl.h b/source/slang/slang-emit-wgsl.h
new file mode 100644
index 000000000..dacd11c3d
--- /dev/null
+++ b/source/slang/slang-emit-wgsl.h
@@ -0,0 +1,71 @@
+#pragma once
+
+#include "slang-emit-c-like.h"
+
+namespace Slang
+{
+
+class WGSLSourceEmitter : public CLikeSourceEmitter
+{
+public:
+
+ WGSLSourceEmitter(const Desc& desc)
+ : CLikeSourceEmitter(desc)
+ {}
+
+ virtual void emitParameterGroupImpl(
+ IRGlobalParam* varDecl, IRUniformParameterGroupType* type
+ ) SLANG_OVERRIDE;
+ virtual void emitEntryPointAttributesImpl(
+ IRFunc* irFunc, IREntryPointDecoration* entryPointDecor
+ ) SLANG_OVERRIDE;
+ virtual void emitSimpleTypeImpl(IRType* type) SLANG_OVERRIDE;
+ virtual void emitVectorTypeNameImpl(
+ IRType* elementType, IRIntegerValue elementCount
+ ) SLANG_OVERRIDE;
+ virtual void emitFuncHeaderImpl(IRFunc* func) SLANG_OVERRIDE;
+ virtual void emitSimpleValueImpl(IRInst* inst) SLANG_OVERRIDE;
+ virtual bool tryEmitInstExprImpl(
+ IRInst* inst, const EmitOpInfo& inOuterPrec
+ ) SLANG_OVERRIDE;
+ virtual void emitSwitchCaseSelectorsImpl(
+ IRBasicType *const switchCondition,
+ const SwitchRegion::Case *const currentCase,
+ const bool isDefault
+ ) SLANG_OVERRIDE;
+ virtual void emitSimpleTypeAndDeclaratorImpl(
+ IRType* type, DeclaratorInfo* declarator
+ ) SLANG_OVERRIDE;
+ virtual void emitVarKeywordImpl(IRType * type, const bool isConstant) SLANG_OVERRIDE;
+ virtual void emitDeclaratorImpl(DeclaratorInfo* declarator) SLANG_OVERRIDE;
+ virtual void emitStructDeclarationSeparatorImpl() SLANG_OVERRIDE;
+ virtual void emitLayoutQualifiersImpl(IRVarLayout* layout) SLANG_OVERRIDE;
+ virtual void emitSimpleFuncParamImpl(IRParam* param) SLANG_OVERRIDE;
+ virtual void emitParamTypeImpl(IRType* type, const String& name) SLANG_OVERRIDE;
+ virtual bool isPointerSyntaxRequiredImpl(IRInst* inst) SLANG_OVERRIDE;
+ virtual void _emitType(IRType* type, DeclaratorInfo* declarator) SLANG_OVERRIDE;
+ virtual void emitFrontMatterImpl(TargetRequest* targetReq) SLANG_OVERRIDE;
+ virtual void emitStructFieldAttributes(
+ IRStructType * structType, IRStructField * field
+ ) SLANG_OVERRIDE;
+ virtual void emitGlobalParamType(IRType* type, const String& name) SLANG_OVERRIDE;
+ virtual void emitOperandImpl(
+ IRInst* inst, const EmitOpInfo& outerPrec
+ ) SLANG_OVERRIDE;
+
+ void emit(const AddressSpace addressSpace);
+
+private:
+
+ // Emit the matrix type with 'rowCountWGSL' WGSL-rows and 'colCountWGSL' WGSL-columns
+ void emitMatrixType(
+ IRType *const elementType,
+ const IRIntegerValue& rowCountWGSL,
+ const IRIntegerValue& colCountWGSL
+ );
+
+ bool m_f16ExtensionEnabled {false};
+
+};
+
+} // namespace Slang
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index ed9e90462..2ccf075f3 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -31,6 +31,7 @@
#include "slang-ir-glsl-legalize.h"
#include "slang-ir-hlsl-legalize.h"
#include "slang-ir-metal-legalize.h"
+#include "slang-ir-wgsl-legalize.h"
#include "slang-ir-insts.h"
#include "slang-ir-inline.h"
#include "slang-ir-legalize-array-return-type.h"
@@ -101,6 +102,7 @@
#include "slang-emit-glsl.h"
#include "slang-emit-hlsl.h"
#include "slang-emit-metal.h"
+#include "slang-emit-wgsl.h"
#include "slang-emit-cpp.h"
#include "slang-emit-cuda.h"
#include "slang-emit-torch.h"
@@ -1234,6 +1236,12 @@ Result linkAndOptimizeIR(
}
break;
+ case CodeGenTarget::WGSL:
+ {
+ legalizeIRForWGSL(irModule, sink);
+ }
+ break;
+
default:
break;
}
@@ -1535,15 +1543,28 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outAr
auto targetProgram = getTargetProgram();
auto lineDirectiveMode = targetProgram->getOptionSet().getEnumOption<LineDirectiveMode>(CompilerOptionName::LineDirectiveMode);
- // To try to make the default behavior reasonable, we will
- // always use C-style line directives (to give the user
- // good source locations on error messages from downstream
- // compilers) *unless* they requested raw GLSL as the
- // output (in which case we want to maximize compatibility
- // with downstream tools).
- if (lineDirectiveMode == LineDirectiveMode::Default && targetRequest->getTarget() == CodeGenTarget::GLSL)
+ // We will generally use C-style line directives in order to give the user good
+ // source locations on error messages from downstream compilers, but there are
+ // a few exceptions.
+ if (lineDirectiveMode == LineDirectiveMode::Default)
{
- lineDirectiveMode = LineDirectiveMode::GLSL;
+
+ switch(targetRequest->getTarget())
+ {
+
+ case CodeGenTarget::GLSL:
+ // We want to maximize compatibility with downstream tools.
+ lineDirectiveMode = LineDirectiveMode::GLSL;
+ break;
+
+ case CodeGenTarget::WGSL:
+ // WGSL doesn't support line directives.
+ // See https://github.com/gpuweb/gpuweb/issues/606.
+ lineDirectiveMode = LineDirectiveMode::None;
+ break;
+
+ }
+
}
ComPtr<IBoxValue<SourceMap>> sourceMap;
@@ -1610,6 +1631,11 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outAr
sourceEmitter = new MetalSourceEmitter(desc);
break;
}
+ case SourceLanguage::WGSL:
+ {
+ sourceEmitter = new WGSLSourceEmitter(desc);
+ break;
+ }
default: break;
}
break;
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index 5865d5320..01b1c20de 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -1511,6 +1511,7 @@ static bool doesTargetAllowUnresolvedFuncSymbol(TargetRequest* req)
case CodeGenTarget::Metal:
case CodeGenTarget::MetalLib:
case CodeGenTarget::MetalLibAssembly:
+ case CodeGenTarget::WGSL:
case CodeGenTarget::DXIL:
case CodeGenTarget::DXILAssembly:
case CodeGenTarget::HostCPPSource:
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp
index a480ae673..d0ad7483a 100644
--- a/source/slang/slang-ir-lower-buffer-element-type.cpp
+++ b/source/slang/slang-ir-lower-buffer-element-type.cpp
@@ -888,16 +888,19 @@ namespace Slang
IRTypeLayoutRules* getTypeLayoutRuleForBuffer(TargetProgram* target, IRType* bufferType)
{
- if (!isKhronosTarget(target->getTargetReq()))
- return IRTypeLayoutRules::getNatural();
+ if (target->getTargetReq()->getTarget() != CodeGenTarget::WGSL)
+ {
+ if (!isKhronosTarget(target->getTargetReq()))
+ return IRTypeLayoutRules::getNatural();
- // If we are just emitting GLSL, we can just use the general layout rule.
- if (!target->shouldEmitSPIRVDirectly())
- return IRTypeLayoutRules::getNatural();
+ // If we are just emitting GLSL, we can just use the general layout rule.
+ if (!target->shouldEmitSPIRVDirectly())
+ return IRTypeLayoutRules::getNatural();
- // If the user specified a scalar buffer layout, then just use that.
- if (target->getOptionSet().shouldUseScalarLayout())
- return IRTypeLayoutRules::getNatural();
+ // If the user specified a scalar buffer layout, then just use that.
+ if (target->getOptionSet().shouldUseScalarLayout())
+ return IRTypeLayoutRules::getNatural();
+ }
if (target->getOptionSet().shouldUseDXLayout())
{
diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp
new file mode 100644
index 000000000..e05eba78c
--- /dev/null
+++ b/source/slang/slang-ir-wgsl-legalize.cpp
@@ -0,0 +1,347 @@
+#include "slang-ir-wgsl-legalize.h"
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-util.h"
+#include "slang-parameter-binding.h"
+#include "slang-ir-legalize-varying-params.h"
+
+namespace Slang
+{
+
+ struct EntryPointInfo
+ {
+ IRFunc* entryPointFunc;
+ IREntryPointDecoration* entryPointDecor;
+ };
+
+ struct SystemValLegalizationWorkItem
+ {
+ IRInst* var;
+ String attrName;
+ UInt attrIndex;
+ };
+
+ struct WGSLSystemValueInfo
+ {
+ String wgslSystemValueName;
+ SystemValueSemanticName wgslSystemValueNameEnum;
+ ShortList<IRType*> permittedTypes;
+ bool isUnsupported = false;
+ };
+
+ struct LegalizeWGSLEntryPointContext
+ {
+ LegalizeWGSLEntryPointContext(DiagnosticSink* sink, IRModule* module) :
+ m_sink(sink), m_module(module) {}
+
+ DiagnosticSink* m_sink;
+ IRModule* m_module;
+
+ std::optional<SystemValLegalizationWorkItem> makeSystemValWorkItem(IRInst* var);
+ void legalizeSystemValue(
+ EntryPointInfo entryPoint, SystemValLegalizationWorkItem& workItem
+ );
+ List<SystemValLegalizationWorkItem> collectSystemValFromEntryPoint(
+ EntryPointInfo entryPoint
+ );
+ void legalizeSystemValueParameters(EntryPointInfo entryPoint);
+ void legalizeEntryPointForWGSL(EntryPointInfo entryPoint);
+ IRInst* tryConvertValue(IRBuilder& builder, IRInst* val, IRType* toType);
+ WGSLSystemValueInfo getSystemValueInfo(
+ String inSemanticName, String* optionalSemanticIndex, IRInst* parentVar
+ );
+ };
+
+ IRInst* LegalizeWGSLEntryPointContext::tryConvertValue(
+ IRBuilder& builder, IRInst* val, IRType* toType
+ )
+ {
+ auto fromType = val->getFullType();
+ if (auto fromVector = as<IRVectorType>(fromType))
+ {
+ if (auto toVector = as<IRVectorType>(toType))
+ {
+ if (fromVector->getElementCount() != toVector->getElementCount())
+ {
+ fromType =
+ builder.getVectorType(
+ fromVector->getElementType(), toVector->getElementCount()
+ );
+ val = builder.emitVectorReshape(fromType, val);
+ }
+ }
+ else if (as<IRBasicType>(toType))
+ {
+ UInt index = 0;
+ val = builder.emitSwizzle(fromVector->getElementType(), val, 1, &index);
+ if (toType->getOp() == kIROp_VoidType)
+ return nullptr;
+ }
+ }
+ else if (auto fromBasicType = as<IRBasicType>(fromType))
+ {
+ if (fromBasicType->getOp() == kIROp_VoidType)
+ return nullptr;
+ if (!as<IRBasicType>(toType))
+ return nullptr;
+ if (toType->getOp() == kIROp_VoidType)
+ return nullptr;
+ }
+ else
+ {
+ return nullptr;
+ }
+ return builder.emitCast(toType, val);
+ }
+
+
+ WGSLSystemValueInfo LegalizeWGSLEntryPointContext::getSystemValueInfo(
+ String inSemanticName, String* optionalSemanticIndex, IRInst* parentVar
+ )
+ {
+ IRBuilder builder(m_module);
+ WGSLSystemValueInfo result = {};
+ UnownedStringSlice semanticName;
+ UnownedStringSlice semanticIndex;
+
+ auto hasExplicitIndex =
+ splitNameAndIndex(
+ inSemanticName.getUnownedSlice(), semanticName, semanticIndex
+ );
+ if (!hasExplicitIndex && optionalSemanticIndex)
+ semanticIndex = optionalSemanticIndex->getUnownedSlice();
+
+ result.wgslSystemValueNameEnum =
+ convertSystemValueSemanticNameToEnum(semanticName);
+
+ switch (result.wgslSystemValueNameEnum)
+ {
+
+ case SystemValueSemanticName::DispatchThreadID:
+ {
+ result.wgslSystemValueName = toSlice("global_invocation_id");
+ IRType *const vec3uType {
+ builder.getVectorType(
+ builder.getBasicType(BaseType::UInt),
+ builder.getIntValue(builder.getIntType(), 3)
+ )
+ };
+ result.permittedTypes.add(vec3uType);
+ }
+ break;
+
+ case SystemValueSemanticName::GroupID:
+ {
+ result.wgslSystemValueName = toSlice("workgroup_id");
+ result.permittedTypes.add(
+ builder.getVectorType(
+ builder.getBasicType(BaseType::UInt),
+ builder.getIntValue(builder.getIntType(), 3)
+ )
+ );
+ }
+ break;
+
+ case SystemValueSemanticName::GroupThreadID:
+ {
+ result.wgslSystemValueName = toSlice("local_invocation_id");
+ result.permittedTypes.add(
+ builder.getVectorType(
+ builder.getBasicType(BaseType::UInt),
+ builder.getIntValue(builder.getIntType(), 3)
+ )
+ );
+ }
+ break;
+
+ case SystemValueSemanticName::GSInstanceID:
+ {
+ // No Geometry shaders in WGSL
+ result.isUnsupported = true;
+ }
+ break;
+
+ default:
+ {
+ m_sink->diagnose(
+ parentVar,
+ Diagnostics::unimplementedSystemValueSemantic, semanticName
+ );
+ return result;
+ }
+
+ }
+
+ return result;
+ }
+
+ std::optional<SystemValLegalizationWorkItem>
+ LegalizeWGSLEntryPointContext::makeSystemValWorkItem(IRInst* var)
+ {
+ if (auto semanticDecoration = var->findDecoration<IRSemanticDecoration>())
+ {
+ bool svPrefix =
+ semanticDecoration->getSemanticName().startsWithCaseInsensitive(
+ toSlice("sv_")
+ );
+ if (svPrefix)
+ {
+ return
+ {
+ {
+ var,
+ String(semanticDecoration->getSemanticName()).toLower(),
+ (UInt)semanticDecoration->getSemanticIndex()
+ }
+ };
+ }
+ }
+
+ auto layoutDecor = var->findDecoration<IRLayoutDecoration>();
+ if (!layoutDecor)
+ return {};
+ auto sysValAttr = layoutDecor->findAttr<IRSystemValueSemanticAttr>();
+ if (!sysValAttr)
+ return {};
+ auto semanticName = String(sysValAttr->getName());
+ auto sysAttrIndex = sysValAttr->getIndex();
+
+ return { { var, semanticName, sysAttrIndex } };
+ }
+
+ List<SystemValLegalizationWorkItem>
+ LegalizeWGSLEntryPointContext::collectSystemValFromEntryPoint(
+ EntryPointInfo entryPoint
+ )
+ {
+ List<SystemValLegalizationWorkItem> systemValWorkItems;
+ for (auto param : entryPoint.entryPointFunc->getParams())
+ {
+ auto maybeWorkItem = makeSystemValWorkItem(param);
+ if (maybeWorkItem.has_value())
+ systemValWorkItems.add(std::move(maybeWorkItem.value()));
+ }
+ return systemValWorkItems;
+ }
+
+ void
+ LegalizeWGSLEntryPointContext::legalizeSystemValue(
+ EntryPointInfo entryPoint, SystemValLegalizationWorkItem& workItem
+ )
+ {
+ IRBuilder builder(entryPoint.entryPointFunc);
+
+ auto var = workItem.var;
+ auto semanticName = workItem.attrName;
+
+ auto indexAsString = String(workItem.attrIndex);
+ auto info = getSystemValueInfo(semanticName, &indexAsString, var);
+
+ if (!info.permittedTypes.getCount())
+ return;
+
+ builder.addTargetSystemValueDecoration(
+ var, info.wgslSystemValueName.getUnownedSlice()
+ );
+
+ bool varTypeIsPermitted = false;
+ auto varType = var->getFullType();
+ for (auto& permittedType : info.permittedTypes)
+ {
+ varTypeIsPermitted = varTypeIsPermitted || permittedType == varType;
+ }
+
+ if (!varTypeIsPermitted)
+ {
+ // Note: we do not currently prefer any conversion
+ // example:
+ // * allowed types for semantic: `float4`, `uint4`, `int4`
+ // * user used, `float2`
+ // * Slang will equally prefer `float4` to `uint4` to `int4`.
+ // This means the type may lose data if slang selects `uint4` or `int4`.
+ bool foundAConversion = false;
+ for (auto permittedType : info.permittedTypes)
+ {
+ var->setFullType(permittedType);
+ builder.setInsertBefore(
+ entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst()
+ );
+
+ // get uses before we `tryConvertValue` since this creates a new use
+ List<IRUse*> uses;
+ for (auto use = var->firstUse; use; use = use->nextUse)
+ uses.add(use);
+
+ auto convertedValue = tryConvertValue(builder, var, varType);
+ if (convertedValue == nullptr)
+ continue;
+
+ foundAConversion = true;
+ copyNameHintAndDebugDecorations(convertedValue, var);
+
+ for (auto use : uses)
+ builder.replaceOperand(use, convertedValue);
+ }
+ if (!foundAConversion)
+ {
+ // If we can't convert the value, report an error.
+ for (auto permittedType : info.permittedTypes)
+ {
+ StringBuilder typeNameSB;
+ getTypeNameHint(typeNameSB, permittedType);
+ m_sink->diagnose(
+ var->sourceLoc,
+ Diagnostics::systemValueTypeIncompatible,
+ semanticName,
+ typeNameSB.produceString()
+ );
+ }
+ }
+ }
+ }
+
+ void LegalizeWGSLEntryPointContext::legalizeSystemValueParameters(
+ EntryPointInfo entryPoint
+ )
+ {
+ List<SystemValLegalizationWorkItem> systemValWorkItems =
+ collectSystemValFromEntryPoint(entryPoint);
+
+ for (auto index = 0; index < systemValWorkItems.getCount(); index++)
+ {
+ legalizeSystemValue(entryPoint, systemValWorkItems[index]);
+ }
+ }
+
+ void LegalizeWGSLEntryPointContext::legalizeEntryPointForWGSL(
+ EntryPointInfo entryPoint
+ )
+ {
+ legalizeSystemValueParameters(entryPoint);
+ }
+
+ void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink)
+ {
+ List<EntryPointInfo> entryPoints;
+ for (auto inst : module->getGlobalInsts())
+ {
+ IRFunc *const func {as<IRFunc>(inst)};
+ if (!func)
+ continue;
+ IREntryPointDecoration *const entryPointDecor =
+ func->findDecoration<IREntryPointDecoration>();
+ if (!entryPointDecor)
+ continue;
+ EntryPointInfo info;
+ info.entryPointDecor = entryPointDecor;
+ info.entryPointFunc = func;
+ entryPoints.add(info);
+ }
+
+ LegalizeWGSLEntryPointContext context(sink, module);
+ for (auto entryPoint : entryPoints)
+ context.legalizeEntryPointForWGSL(entryPoint);
+ }
+
+}
diff --git a/source/slang/slang-ir-wgsl-legalize.h b/source/slang/slang-ir-wgsl-legalize.h
new file mode 100644
index 000000000..462f93204
--- /dev/null
+++ b/source/slang/slang-ir-wgsl-legalize.h
@@ -0,0 +1,10 @@
+#pragma once
+
+#include "slang-ir.h"
+
+namespace Slang
+{
+ class DiagnosticSink;
+
+ void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink);
+}
diff --git a/source/slang/slang-profile.h b/source/slang/slang-profile.h
index 04d4f5112..178fbddd5 100644
--- a/source/slang/slang-profile.h
+++ b/source/slang/slang-profile.h
@@ -19,6 +19,7 @@ namespace Slang
CUDA = SLANG_SOURCE_LANGUAGE_CUDA,
SPIRV = SLANG_SOURCE_LANGUAGE_SPIRV,
Metal = SLANG_SOURCE_LANGUAGE_METAL,
+ WGSL = SLANG_SOURCE_LANGUAGE_WGSL,
CountOf = SLANG_SOURCE_LANGUAGE_COUNT_OF,
};
diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp
index f654135a1..2447f5787 100644
--- a/source/slang/slang-type-layout.cpp
+++ b/source/slang/slang-type-layout.cpp
@@ -1831,6 +1831,7 @@ LayoutRulesFamilyImpl* getDefaultLayoutRulesFamilyForTarget(TargetRequest* targe
case CodeGenTarget::GLSL:
case CodeGenTarget::SPIRV:
case CodeGenTarget::SPIRVAssembly:
+ case CodeGenTarget::WGSL:
return &kGLSLLayoutRulesFamilyImpl;
case CodeGenTarget::HostHostCallable:
@@ -2141,6 +2142,10 @@ SourceLanguage getIntermediateSourceLanguageForTarget(TargetProgram* targetProgr
{
return SourceLanguage::Metal;
}
+ case CodeGenTarget::WGSL:
+ {
+ return SourceLanguage::WGSL;
+ }
case CodeGenTarget::CSource:
{
return SourceLanguage::C;
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index 91ed3de5f..c78348a86 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -1838,6 +1838,10 @@ CapabilitySet TargetRequest::getTargetCaps()
atoms.add(CapabilityName::metal);
break;
+ case CodeGenTarget::WGSL:
+ atoms.add(CapabilityName::wgsl);
+ break;
+
default:
break;
}