summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--build/visual-studio/slang/slang.vcxproj2
-rw-r--r--build/visual-studio/slang/slang.vcxproj.filters6
-rw-r--r--slang.h3
-rw-r--r--source/compiler-core/slang-artifact-desc-util.cpp1
-rw-r--r--source/core/slang-type-text-util.cpp1
-rw-r--r--source/slang/slang-capabilities.capdef1
-rw-r--r--source/slang/slang-compiler.cpp7
-rwxr-xr-xsource/slang/slang-compiler.h2
-rw-r--r--source/slang/slang-doc-markdown-writer.cpp5
-rw-r--r--source/slang/slang-emit-c-like.cpp5
-rw-r--r--source/slang/slang-emit-metal.cpp661
-rw-r--r--source/slang/slang-emit-metal.h81
-rw-r--r--source/slang/slang-emit.cpp6
-rw-r--r--source/slang/slang-ir-link.cpp1
-rw-r--r--source/slang/slang-ir-util.cpp7
-rw-r--r--source/slang/slang-ir-util.h2
-rw-r--r--source/slang/slang-profile.h1
-rw-r--r--source/slang/slang-type-layout.cpp5
-rw-r--r--source/slang/slang.cpp4
-rw-r--r--tests/metal/simple-compute.slang10
-rw-r--r--tools/slang-test/slang-test-main.cpp1
21 files changed, 809 insertions, 3 deletions
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj
index 45f7abb80..f2d470d93 100644
--- a/build/visual-studio/slang/slang.vcxproj
+++ b/build/visual-studio/slang/slang.vcxproj
@@ -342,6 +342,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
<ClInclude Include="..\..\..\source\slang\slang-emit-cuda.h" />
<ClInclude Include="..\..\..\source\slang\slang-emit-glsl.h" />
<ClInclude Include="..\..\..\source\slang\slang-emit-hlsl.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-emit-metal.h" />
<ClInclude Include="..\..\..\source\slang\slang-emit-precedence.h" />
<ClInclude Include="..\..\..\source\slang\slang-emit-source-writer.h" />
<ClInclude Include="..\..\..\source\slang\slang-emit-spirv-ops-debug-info-ext.h" />
@@ -575,6 +576,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
<ClCompile Include="..\..\..\source\slang\slang-emit-cuda.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-emit-glsl.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-emit-hlsl.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-emit-metal.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-emit-precedence.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-emit-source-writer.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-emit-spirv.cpp" />
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters
index 9403d3649..ab06900ba 100644
--- a/build/visual-studio/slang/slang.vcxproj.filters
+++ b/build/visual-studio/slang/slang.vcxproj.filters
@@ -114,6 +114,9 @@
<ClInclude Include="..\..\..\source\slang\slang-emit-hlsl.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-emit-metal.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-emit-precedence.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -809,6 +812,9 @@
<ClCompile Include="..\..\..\source\slang\slang-emit-hlsl.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-emit-metal.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-emit-precedence.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/slang.h b/slang.h
index 77e9d3bd9..7ee7c2697 100644
--- a/slang.h
+++ b/slang.h
@@ -609,6 +609,7 @@ extern "C"
SLANG_HOST_CPP_SOURCE, ///< C++ code for host library or executable.
SLANG_HOST_HOST_CALLABLE, ///< Host callable host code (ie non kernel/shader)
SLANG_CPP_PYTORCH_BINDING, ///< C++ PyTorch binding code.
+ SLANG_METAL, ///< Metal shading language
SLANG_TARGET_COUNT_OF,
};
@@ -641,6 +642,7 @@ extern "C"
SLANG_PASS_THROUGH_NVRTC, ///< NVRTC Cuda compiler
SLANG_PASS_THROUGH_LLVM, ///< LLVM 'compiler' - includes LLVM and Clang
SLANG_PASS_THROUGH_SPIRV_OPT, ///< SPIRV-opt
+ SLANG_PASS_THROUGH_METAL, ///< Metal compiler
SLANG_PASS_THROUGH_COUNT_OF,
};
@@ -743,6 +745,7 @@ extern "C"
SLANG_SOURCE_LANGUAGE_CPP,
SLANG_SOURCE_LANGUAGE_CUDA,
SLANG_SOURCE_LANGUAGE_SPIRV,
+ SLANG_SOURCE_LANGUAGE_METAL,
SLANG_SOURCE_LANGUAGE_COUNT_OF,
};
diff --git a/source/compiler-core/slang-artifact-desc-util.cpp b/source/compiler-core/slang-artifact-desc-util.cpp
index 559a1b4ca..2646555bc 100644
--- a/source/compiler-core/slang-artifact-desc-util.cpp
+++ b/source/compiler-core/slang-artifact-desc-util.cpp
@@ -286,6 +286,7 @@ SLANG_HIERARCHICAL_ENUM(ArtifactStyle, SLANG_ARTIFACT_STYLE, SLANG_ARTIFACT_STYL
case SLANG_PTX: return Desc::make(Kind::Executable, Payload::PTX, Style::Kernel, 0);
case SLANG_OBJECT_CODE: return Desc::make(Kind::ObjectCode, Payload::HostCPU, Style::Kernel, 0);
case SLANG_HOST_HOST_CALLABLE: return Desc::make(Kind::HostCallable, Payload::HostCPU, Style::Host, 0);
+ case SLANG_METAL: return Desc::make(Kind::Source, Payload::Metal, Style::Kernel, 0);
default: break;
}
diff --git a/source/core/slang-type-text-util.cpp b/source/core/slang-type-text-util.cpp
index 676be3976..7e1ffc439 100644
--- a/source/core/slang-type-text-util.cpp
+++ b/source/core/slang-type-text-util.cpp
@@ -61,6 +61,7 @@ static const TypeTextUtil::CompileTargetInfo s_compileTargetInfos[] =
{ SLANG_SHADER_HOST_CALLABLE, "", "host-callable,callable", "Host callable" },
{ SLANG_OBJECT_CODE, "obj,o", "object-code", "Object code" },
{ SLANG_HOST_HOST_CALLABLE, "", "host-host-callable", "Host callable for host execution" },
+ { SLANG_METAL, "metal", "metal", "Metal shader source"},
};
static const NamesDescriptionValue s_languageInfos[] =
diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef
index b83585c5c..eb546ae6b 100644
--- a/source/slang/slang-capabilities.capdef
+++ b/source/slang/slang-capabilities.capdef
@@ -45,6 +45,7 @@ def glsl : target + textualTarget;
def c : target + textualTarget;
def cpp : target + textualTarget;
def cuda : target + textualTarget;
+def metal : target + textualTarget;
// We have multiple capabilities for the various SPIR-V versions,
// arranged so that they inherit from one another to represent which versions
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp
index 3f5351ff4..04536c81a 100644
--- a/source/slang/slang-compiler.cpp
+++ b/source/slang/slang-compiler.cpp
@@ -480,7 +480,10 @@ namespace Slang
{
return SourceLanguage::SPIRV;
}
-
+ case PassThroughMode::MetalC:
+ {
+ return SourceLanguage::Metal;
+ }
default: break;
}
SLANG_ASSERT(!"Unknown compiler");
@@ -499,6 +502,7 @@ namespace Slang
case CodeGenTarget::HostCPPSource:
case CodeGenTarget::PyTorchCppBinding:
case CodeGenTarget::CSource:
+ case CodeGenTarget::Metal:
{
return PassThroughMode::None;
}
@@ -1617,6 +1621,7 @@ namespace Slang
case CodeGenTarget::HostCPPSource:
case CodeGenTarget::PyTorchCppBinding:
case CodeGenTarget::CSource:
+ case CodeGenTarget::Metal:
{
RefPtr<ExtensionTracker> extensionTracker = _newExtensionTracker(target);
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index 014b678f5..7f7903f8b 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -92,6 +92,7 @@ namespace Slang
CUDAObjectCode = SLANG_CUDA_OBJECT_CODE,
ObjectCode = SLANG_OBJECT_CODE,
HostHostCallable = SLANG_HOST_HOST_CALLABLE,
+ Metal = SLANG_METAL,
CountOf = SLANG_TARGET_COUNT_OF,
};
@@ -1236,6 +1237,7 @@ namespace Slang
NVRTC = SLANG_PASS_THROUGH_NVRTC, ///< NVRTC CUDA compiler
LLVM = SLANG_PASS_THROUGH_LLVM, ///< LLVM 'compiler'
SpirvOpt = SLANG_PASS_THROUGH_SPIRV_OPT, ///< pass thorugh spirv to spirv-opt
+ MetalC = SLANG_PASS_THROUGH_METAL,
CountOf = SLANG_PASS_THROUGH_COUNT_OF,
};
void printDiagnosticArg(StringBuilder& sb, PassThroughMode val);
diff --git a/source/slang/slang-doc-markdown-writer.cpp b/source/slang/slang-doc-markdown-writer.cpp
index 77b239437..ac3b9ca7e 100644
--- a/source/slang/slang-doc-markdown-writer.cpp
+++ b/source/slang/slang-doc-markdown-writer.cpp
@@ -479,7 +479,10 @@ static DocMarkdownWriter::Requirement _getRequirementFromTargetToken(const Token
{
return Requirement{ CodeGenTarget::CSource, targetName };
}
-
+ else if (isCapabilityDerivedFrom(targetCap, CapabilityAtom::metal))
+ {
+ return Requirement{ CodeGenTarget::Metal, targetName };
+ }
return Requirement{ CodeGenTarget::Unknown, String() };
}
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index ceb4e6de0..7cb4871be 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -91,6 +91,10 @@ struct CLikeSourceEmitter::ComputeEmitActionsContext
{
return SourceLanguage::CUDA;
}
+ case CodeGenTarget::Metal:
+ {
+ return SourceLanguage::Metal;
+ }
}
}
@@ -3699,7 +3703,6 @@ void CLikeSourceEmitter::emitVarModifiers(IRVarLayout* layout, IRInst* varDecl,
{
// TODO(JS): We could push all of this onto the target impls, and then not need so many virtual hooks.
emitVarDecorationsImpl(varDecl);
-
emitTempModifiers(varDecl);
if (!layout)
diff --git a/source/slang/slang-emit-metal.cpp b/source/slang/slang-emit-metal.cpp
new file mode 100644
index 000000000..1ce25a2da
--- /dev/null
+++ b/source/slang/slang-emit-metal.cpp
@@ -0,0 +1,661 @@
+// slang-emit-metal.cpp
+#include "slang-emit-metal.h"
+
+#include "../core/slang-writer.h"
+
+#include "slang-ir-util.h"
+#include "slang-emit-source-writer.h"
+#include "slang-mangled-lexer.h"
+
+#include <assert.h>
+
+namespace Slang {
+
+void MetalSourceEmitter::_emitHLSLDecorationSingleString(const char* name, IRFunc* entryPoint, IRStringLit* val)
+{
+ SLANG_UNUSED(entryPoint);
+ assert(val);
+
+ m_writer->emit("[[");
+ m_writer->emit(name);
+ m_writer->emit("(\"");
+ m_writer->emit(val->getStringSlice());
+ m_writer->emit("\")]]\n");
+}
+
+void MetalSourceEmitter::_emitHLSLDecorationSingleInt(const char* name, IRFunc* entryPoint, IRIntLit* val)
+{
+ SLANG_UNUSED(entryPoint);
+ SLANG_ASSERT(val);
+
+ auto intVal = getIntVal(val);
+
+ m_writer->emit("[[");
+ m_writer->emit(name);
+ m_writer->emit("(");
+ m_writer->emit(intVal);
+ m_writer->emit(")]]\n");
+}
+
+void MetalSourceEmitter::_emitHLSLRegisterSemantic(LayoutResourceKind kind, EmitVarChain* chain, IRInst* inst, char const* uniformSemanticSpelling)
+{
+ // Metal does not use explicit binding.
+ SLANG_UNUSED(kind);
+ SLANG_UNUSED(chain);
+ SLANG_UNUSED(inst);
+ SLANG_UNUSED(uniformSemanticSpelling);
+}
+
+void MetalSourceEmitter::_emitHLSLRegisterSemantics(EmitVarChain* chain, IRInst* inst, char const* uniformSemanticSpelling)
+{
+ // TODO: implement.
+ SLANG_UNUSED(chain);
+ SLANG_UNUSED(inst);
+ SLANG_UNUSED(uniformSemanticSpelling);
+}
+
+void MetalSourceEmitter::_emitHLSLRegisterSemantics(IRVarLayout* varLayout, IRInst* inst, char const* uniformSemanticSpelling)
+{
+ // TODO: implement.
+ SLANG_UNUSED(varLayout);
+ SLANG_UNUSED(inst);
+ SLANG_UNUSED(uniformSemanticSpelling);
+}
+
+void MetalSourceEmitter::_emitHLSLParameterGroupFieldLayoutSemantics(EmitVarChain* chain)
+{
+ // TODO: implement.
+ SLANG_UNUSED(chain);
+}
+
+void MetalSourceEmitter::_emitHLSLParameterGroupFieldLayoutSemantics(IRVarLayout* fieldLayout, EmitVarChain* inChain)
+{
+ // TODO: implement.
+ SLANG_UNUSED(fieldLayout);
+ SLANG_UNUSED(inChain);
+}
+
+void MetalSourceEmitter::_emitHLSLParameterGroup(IRGlobalParam* varDecl, IRUniformParameterGroupType* type)
+{
+ // Metal does not allow shader parameters declared as global variables, so we shouldn't see this.
+ SLANG_UNUSED(varDecl);
+ SLANG_UNUSED(type);
+ SLANG_ASSERT(!"Metal does not allow shader parameters declared as global variables.");
+}
+
+void MetalSourceEmitter::_emitHLSLTextureType(IRTextureTypeBase* texType)
+{
+ if (getIntVal(texType->getIsShadowInst()) != 0)
+ {
+ m_writer->emit("depth");
+ }
+ else
+ {
+ m_writer->emit("texture");
+ }
+
+ switch (texType->GetBaseShape())
+ {
+ case SLANG_TEXTURE_1D: m_writer->emit("1d"); break;
+ case SLANG_TEXTURE_2D: m_writer->emit("2d"); break;
+ case SLANG_TEXTURE_3D: m_writer->emit("3d"); break;
+ case SLANG_TEXTURE_CUBE: m_writer->emit("cube"); break;
+ case SLANG_TEXTURE_BUFFER: m_writer->emit("1d"); break;
+ default:
+ SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unhandled resource shape");
+ break;
+ }
+
+ if (texType->isMultisample())
+ {
+ m_writer->emit("_ms");
+ }
+ if (texType->isArray())
+ {
+ m_writer->emit("_array");
+ }
+ m_writer->emit("<");
+ emitType(getVectorElementType(texType->getElementType()));
+ m_writer->emit(", ");
+
+ switch (texType->getAccess())
+ {
+ case SLANG_RESOURCE_ACCESS_READ:
+ m_writer->emit("access::sample");
+ break;
+
+ case SLANG_RESOURCE_ACCESS_READ_WRITE:
+ case SLANG_RESOURCE_ACCESS_APPEND:
+ case SLANG_RESOURCE_ACCESS_CONSUME:
+ case SLANG_RESOURCE_ACCESS_FEEDBACK:
+ case SLANG_RESOURCE_ACCESS_RASTER_ORDERED:
+ m_writer->emit("access::read_write");
+ break;
+ default:
+ SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unhandled resource access mode");
+ break;
+ }
+
+ m_writer->emit(">");
+}
+
+void MetalSourceEmitter::_emitHLSLSubpassInputType(IRSubpassInputType* subpassType)
+{
+ SLANG_UNUSED(subpassType);
+}
+
+void MetalSourceEmitter::emitLayoutSemanticsImpl(IRInst* inst, char const* uniformSemanticSpelling)
+{
+ auto layout = getVarLayout(inst);
+ if (layout)
+ {
+ _emitHLSLRegisterSemantics(layout, inst, uniformSemanticSpelling);
+ }
+}
+
+void MetalSourceEmitter::emitParameterGroupImpl(IRGlobalParam* varDecl, IRUniformParameterGroupType* type)
+{
+ _emitHLSLParameterGroup(varDecl, type);
+}
+
+void MetalSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPointDecoration* entryPointDecor)
+{
+ auto profile = m_effectiveProfile;
+ auto stage = entryPointDecor->getProfile().getStage();
+
+ switch (stage)
+ {
+ case Stage::Fragment:
+ m_writer->emit("[[fragment]] ");
+ break;
+ case Stage::Vertex:
+ m_writer->emit("[[vertex]] ");
+ break;
+ case Stage::Compute:
+ m_writer->emit("[[kernel]] ");
+ break;
+ default:
+ SLANG_ABORT_COMPILATION("unsupported stage.");
+ }
+
+ switch (stage)
+ {
+ case Stage::Pixel:
+ {
+ if (irFunc->findDecoration<IREarlyDepthStencilDecoration>())
+ {
+ m_writer->emit("[[early_fragment_tests]]\n");
+ }
+ break;
+ }
+ default:
+ break;
+ }
+}
+
+bool MetalSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec)
+{
+ switch (inst->getOp())
+ {
+ case kIROp_MakeVector:
+ case kIROp_MakeMatrix:
+ {
+ if (inst->getOperandCount() == 1)
+ {
+ EmitOpInfo outerPrec = inOuterPrec;
+ bool needClose = false;
+
+ auto prec = getInfo(EmitOp::Prefix);
+ needClose = maybeEmitParens(outerPrec, prec);
+
+ // Need to emit as cast for HLSL
+ emitType(inst->getDataType());
+ m_writer->emit("(");
+ emitOperand(inst->getOperand(0), rightSide(outerPrec, prec));
+ m_writer->emit(") ");
+
+ maybeCloseParens(needClose);
+ // Handled
+ return true;
+ }
+ break;
+ }
+ case kIROp_BitCast:
+ {
+ auto toType = inst->getDataType();
+
+ m_writer->emit("as_type<");
+ emitType(toType);
+ m_writer->emit(">(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(")");
+ return true;
+ }
+ case kIROp_StringLit:
+ {
+ const auto handler = StringEscapeUtil::getHandler(StringEscapeUtil::Style::Slang);
+
+ StringBuilder buf;
+ const UnownedStringSlice slice = as<IRStringLit>(inst)->getStringSlice();
+ StringEscapeUtil::appendQuoted(handler, slice, buf);
+
+ m_writer->emit(buf);
+
+ return true;
+ }
+ case kIROp_ByteAddressBufferLoad:
+ {
+ // This only works for loads of 4-byte values.
+ // Other element types should have been lowered by previous legalization passes.
+ auto elementType = inst->getDataType();
+ auto buffer = inst->getOperand(0);
+ auto offset = inst->getOperand(1);
+ m_writer->emit("as_type<");
+ emitType(elementType);
+ m_writer->emit(">(");
+ emitOperand(buffer, getInfo(EmitOp::General));
+ m_writer->emit("[(");
+ emitOperand(offset, getInfo(EmitOp::General));
+ m_writer->emit(")>>2)]");
+ return true;
+ }
+ case kIROp_ByteAddressBufferStore:
+ {
+ // This only works for loads of 4-byte values.
+ // Other element types should have been lowered by previous legalization passes.
+ auto buffer = inst->getOperand(0);
+ auto offset = inst->getOperand(1);
+ emitOperand(buffer, getInfo(EmitOp::General));
+ m_writer->emit("[(");
+ emitOperand(offset, getInfo(EmitOp::General));
+ m_writer->emit(")>>2)] = as_type<uint32_t>(");
+ emitOperand(inst->getOperand(2), getInfo(EmitOp::General));
+ m_writer->emit(")");
+ return true;
+ }
+ break;
+
+ default: break;
+ }
+ // Not handled
+ return false;
+}
+
+void MetalSourceEmitter::emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount)
+{
+ // In some cases we *need* to use the built-in syntax sugar for vector types,
+ // so we will try to emit those whenever possible.
+ //
+ if( elementCount >= 1 && elementCount <= 4 )
+ {
+ switch( elementType->getOp() )
+ {
+ case kIROp_FloatType:
+ case kIROp_IntType:
+ case kIROp_UIntType:
+ // TODO: There are more types that need to be covered here
+ emitType(elementType);
+ m_writer->emit(elementCount);
+ return;
+
+ default:
+ break;
+ }
+ }
+
+ // As a fallback, we will use the `vector<...>` type constructor,
+ // although we should not expect to run into types that don't
+ // have a sugared form.
+ //
+ m_writer->emit("vector<");
+ emitType(elementType);
+ m_writer->emit(",");
+ m_writer->emit(elementCount);
+ m_writer->emit(">");
+}
+
+void MetalSourceEmitter::emitLoopControlDecorationImpl(IRLoopControlDecoration* decl)
+{
+ switch (decl->getMode())
+ {
+ case kIRLoopControl_Unroll:
+ m_writer->emit("[unroll]\n");
+ break;
+ case kIRLoopControl_Loop:
+ m_writer->emit("[loop]\n");
+ break;
+ default:
+ break;
+ }
+}
+
+static bool _canEmitExport(const Profile& profile)
+{
+ const auto family = profile.getFamily();
+ const auto version = profile.getVersion();
+ // Is ita late enough version of shader model to output with 'export'
+ return (family == ProfileFamily::DX && version >= ProfileVersion::DX_6_1);
+}
+
+/* virtual */void MetalSourceEmitter::emitFuncDecorationsImpl(IRFunc* func)
+{
+ // Specially handle export, as we don't want to emit it multiple times
+ if (getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::GenerateWholeProgram) &&
+ _canEmitExport(m_effectiveProfile))
+ {
+ for (auto decoration : func->getDecorations())
+ {
+ const auto op = decoration->getOp();
+ if (op == kIROp_PublicDecoration ||
+ op == kIROp_HLSLExportDecoration)
+ {
+ m_writer->emit("export\n");
+ break;
+ }
+ }
+ }
+
+ // Use the default for others
+ Super::emitFuncDecorationsImpl(func);
+}
+
+void MetalSourceEmitter::emitIfDecorationsImpl(IRIfElse* ifInst)
+{
+ // Does not apply to metal.
+ SLANG_UNUSED(ifInst);
+}
+
+void MetalSourceEmitter::emitSwitchDecorationsImpl(IRSwitch* switchInst)
+{
+ // Does not apply to metal.
+ SLANG_UNUSED(switchInst);
+}
+
+void MetalSourceEmitter::emitFuncDecorationImpl(IRDecoration* decoration)
+{
+ // Does not apply to metal.
+ SLANG_UNUSED(decoration);
+}
+
+void MetalSourceEmitter::emitSimpleValueImpl(IRInst* inst)
+{
+ switch (inst->getOp())
+ {
+ case kIROp_FloatLit:
+ {
+ IRConstant* constantInst = static_cast<IRConstant*>(inst);
+ IRConstant::FloatKind kind = constantInst->getFloatKind();
+ switch (kind)
+ {
+ case IRConstant::FloatKind::Nan:
+ {
+ m_writer->emit("(0.0 / 0.0)");
+ return;
+ }
+ case IRConstant::FloatKind::PositiveInfinity:
+ {
+ m_writer->emit("(1.0 / 0.0)");
+ return;
+ }
+ case IRConstant::FloatKind::NegativeInfinity:
+ {
+ m_writer->emit("(-1.0 / 0.0)");
+ return;
+ }
+ default: break;
+ }
+ break;
+ }
+
+ default: break;
+ }
+
+ Super::emitSimpleValueImpl(inst);
+}
+
+void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type)
+{
+ switch (type->getOp())
+ {
+ case kIROp_VoidType:
+ case kIROp_BoolType:
+ case kIROp_Int8Type:
+ case kIROp_IntType:
+ case kIROp_Int64Type:
+ case kIROp_UInt8Type:
+ case kIROp_UIntType:
+ case kIROp_UInt64Type:
+ case kIROp_FloatType:
+ case kIROp_DoubleType:
+ case kIROp_Int16Type:
+ case kIROp_UInt16Type:
+ case kIROp_HalfType:
+ {
+ m_writer->emit(getDefaultBuiltinTypeName(type->getOp()));
+ return;
+ }
+ case kIROp_IntPtrType:
+ m_writer->emit("int64_t");
+ return;
+ case kIROp_UIntPtrType:
+ m_writer->emit("uint64_t");
+ return;
+ case kIROp_StructType:
+ m_writer->emit(getName(type));
+ return;
+
+ case kIROp_VectorType:
+ {
+ auto vecType = (IRVectorType*)type;
+ emitVectorTypeNameImpl(vecType->getElementType(), getIntVal(vecType->getElementCount()));
+ return;
+ }
+ case kIROp_MatrixType:
+ {
+ auto matType = (IRMatrixType*)type;
+
+ // Similar to GLSL, Metal's column-major is really our row-major.
+ m_writer->emit("matrix<");
+ emitType(matType->getElementType());
+ m_writer->emit(",");
+ emitVal(matType->getColumnCount(), getInfo(EmitOp::General));
+ m_writer->emit(",");
+ emitVal(matType->getRowCount(), getInfo(EmitOp::General));
+ m_writer->emit("> ");
+ return;
+ }
+ case kIROp_SamplerStateType:
+ case kIROp_SamplerComparisonStateType:
+ {
+ m_writer->emit("sampler");
+ return;
+ }
+ case kIROp_NativeStringType:
+ case kIROp_StringType:
+ {
+ m_writer->emit("int");
+ return;
+ }
+ case kIROp_ParameterBlockType:
+ case kIROp_ConstantBufferType:
+ {
+ m_writer->emit("constant ");
+ emitType((IRType*)type->getOperand(0));
+ m_writer->emit("*");
+ return;
+ }
+ default: break;
+ }
+
+ if (auto texType = as<IRTextureType>(type))
+ {
+ _emitHLSLTextureType(texType);
+ return;
+ }
+ else if (auto imageType = as<IRGLSLImageType>(type))
+ {
+ _emitHLSLTextureType(imageType);
+ return;
+ }
+ else if (auto structuredBufferType = as<IRHLSLStructuredBufferTypeBase>(type))
+ {
+ m_writer->emit("device ");
+ emitType(structuredBufferType->getElementType());
+ m_writer->emit("*");
+ return;
+ }
+ else if (const auto untypedBufferType = as<IRUntypedBufferResourceType>(type))
+ {
+ switch (type->getOp())
+ {
+ case kIROp_HLSLByteAddressBufferType:
+ case kIROp_HLSLRWByteAddressBufferType:
+ case kIROp_HLSLRasterizerOrderedByteAddressBufferType:
+ m_writer->emit("device ");
+ m_writer->emit("uint32_t *");
+ break;
+ case kIROp_RaytracingAccelerationStructureType: m_writer->emit("acceleration_structure<instancing>"); break;
+ default:
+ SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unhandled buffer type");
+ break;
+ }
+ return;
+ }
+ else if(auto specializedType = as<IRSpecialize>(type))
+ {
+ // If a `specialize` instruction made it this far, then
+ // it represents an intrinsic generic type.
+ //
+ emitSimpleType((IRType*) getSpecializedValue(specializedType));
+ m_writer->emit("<");
+ UInt argCount = specializedType->getArgCount();
+ for (UInt ii = 0; ii < argCount; ++ii)
+ {
+ if (ii != 0) m_writer->emit(", ");
+ emitVal(specializedType->getArg(ii), getInfo(EmitOp::General));
+ }
+ m_writer->emit(" >");
+ return;
+ }
+
+ // HACK: As a fallback for HLSL targets, assume that the name of the
+ // instruction being used is the same as the name of the HLSL type.
+ {
+ auto opInfo = getIROpInfo(type->getOp());
+ m_writer->emit(opInfo.name);
+ UInt operandCount = type->getOperandCount();
+ if (operandCount)
+ {
+ m_writer->emit("<");
+ for (UInt ii = 0; ii < operandCount; ++ii)
+ {
+ if (ii != 0) m_writer->emit(", ");
+ emitVal(type->getOperand(ii), getInfo(EmitOp::General));
+ }
+ m_writer->emit(" >");
+ }
+ }
+}
+
+void MetalSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, [[maybe_unused]] IRIntegerValue addressSpace)
+{
+ if (as<IRGroupSharedRate>(rate))
+ {
+ m_writer->emit("threadgroup ");
+ }
+}
+
+void MetalSourceEmitter::emitSemanticsImpl(IRInst* inst, bool allowOffsets)
+{
+ // Metal does not use semantics.
+ SLANG_UNUSED(inst);
+ SLANG_UNUSED(allowOffsets);
+}
+
+void MetalSourceEmitter::_emitStageAccessSemantic(IRStageAccessDecoration* decoration, const char* name)
+{
+ SLANG_UNUSED(decoration);
+ SLANG_UNUSED(name);
+}
+
+void MetalSourceEmitter::emitSimpleFuncParamImpl(IRParam* param)
+{
+ Super::emitSimpleFuncParamImpl(param);
+}
+
+static UnownedStringSlice _getInterpolationModifierText(IRInterpolationMode mode)
+{
+ switch (mode)
+ {
+ case IRInterpolationMode::PerVertex:
+ case IRInterpolationMode::NoInterpolation: return UnownedStringSlice::fromLiteral("[[flat]]");
+ case IRInterpolationMode::NoPerspective: return UnownedStringSlice::fromLiteral("[[center_no_perspective]]");
+ case IRInterpolationMode::Linear: return UnownedStringSlice::fromLiteral("[[sample_no_perspective]]");
+ case IRInterpolationMode::Sample: return UnownedStringSlice::fromLiteral("[[sample_perspective]]");
+ case IRInterpolationMode::Centroid: return UnownedStringSlice::fromLiteral("[[center_perspective]]");
+ default: return UnownedStringSlice();
+ }
+}
+
+void MetalSourceEmitter::emitInterpolationModifiersImpl(IRInst* varInst, IRType* valueType, IRVarLayout* layout)
+{
+ SLANG_UNUSED(layout);
+ SLANG_UNUSED(valueType);
+
+ for (auto dd : varInst->getDecorations())
+ {
+ if (dd->getOp() != kIROp_InterpolationModeDecoration)
+ continue;
+
+ auto decoration = (IRInterpolationModeDecoration*)dd;
+
+ UnownedStringSlice modeText = _getInterpolationModifierText(decoration->getMode());
+ if (modeText.getLength() > 0)
+ {
+ m_writer->emit(modeText);
+ m_writer->emitChar(' ');
+ }
+ }
+}
+
+void MetalSourceEmitter::emitPackOffsetModifier(IRInst* varInst, IRType* valueType, IRPackOffsetDecoration* layout)
+{
+ SLANG_UNUSED(varInst);
+ SLANG_UNUSED(valueType);
+ SLANG_UNUSED(layout);
+ // We emit packoffset as a semantic in `emitSemantic`, so nothing to do here.
+}
+
+void MetalSourceEmitter::emitMeshShaderModifiersImpl(IRInst* varInst)
+{
+ SLANG_UNUSED(varInst);
+}
+
+void MetalSourceEmitter::emitVarDecorationsImpl(IRInst* varInst)
+{
+ SLANG_UNUSED(varInst);
+}
+
+void MetalSourceEmitter::emitMatrixLayoutModifiersImpl(IRVarLayout*)
+{
+ // Metal only supports column major layout, and we must have
+ // already translated all matrix ops to assume column-major
+ // at this stage.
+}
+
+void MetalSourceEmitter::handleRequiredCapabilitiesImpl(IRInst* inst)
+{
+ SLANG_UNUSED(inst);
+}
+
+void MetalSourceEmitter::emitFrontMatterImpl(TargetRequest*)
+{
+
+}
+
+void MetalSourceEmitter::emitGlobalInstImpl(IRInst* inst)
+{
+ Super::emitGlobalInstImpl(inst);
+}
+
+} // namespace Slang
diff --git a/source/slang/slang-emit-metal.h b/source/slang/slang-emit-metal.h
new file mode 100644
index 000000000..4c4f27be3
--- /dev/null
+++ b/source/slang/slang-emit-metal.h
@@ -0,0 +1,81 @@
+// slang-emit-metal.h
+#ifndef SLANG_EMIT_METAL_H
+#define SLANG_EMIT_METAL_H
+
+#include "slang-emit-c-like.h"
+
+namespace Slang
+{
+class MetalExtensionTracker : public ExtensionTracker {};
+
+class MetalSourceEmitter : public CLikeSourceEmitter
+{
+public:
+ typedef CLikeSourceEmitter Super;
+
+ MetalSourceEmitter(const Desc& desc)
+ : Super(desc)
+ , m_extensionTracker(new MetalExtensionTracker())
+ {}
+
+ virtual RefObject* getExtensionTracker() SLANG_OVERRIDE { return m_extensionTracker; }
+
+protected:
+ RefPtr<MetalExtensionTracker> m_extensionTracker;
+
+ virtual void emitLayoutSemanticsImpl(IRInst* inst, char const* uniformSemanticSpelling) SLANG_OVERRIDE;
+ virtual void emitParameterGroupImpl(IRGlobalParam* varDecl, IRUniformParameterGroupType* type) SLANG_OVERRIDE;
+ virtual void emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPointDecoration* entryPointDecor) SLANG_OVERRIDE;
+
+ virtual void emitFrontMatterImpl(TargetRequest* targetReq) SLANG_OVERRIDE;
+
+ virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace) SLANG_OVERRIDE;
+ virtual void emitSemanticsImpl(IRInst* inst, bool allowOffsets) SLANG_OVERRIDE;
+ virtual void emitSimpleFuncParamImpl(IRParam* param) SLANG_OVERRIDE;
+ virtual void emitInterpolationModifiersImpl(IRInst* varInst, IRType* valueType, IRVarLayout* layout) SLANG_OVERRIDE;
+ virtual void emitPackOffsetModifier(IRInst* varInst, IRType* valueType, IRPackOffsetDecoration* decoration) SLANG_OVERRIDE;
+
+ virtual void emitMeshShaderModifiersImpl(IRInst* varInst) SLANG_OVERRIDE;
+ virtual void emitSimpleTypeImpl(IRType* type) SLANG_OVERRIDE;
+ virtual void emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount) SLANG_OVERRIDE;
+ virtual void emitVarDecorationsImpl(IRInst* varDecl) SLANG_OVERRIDE;
+ virtual void emitMatrixLayoutModifiersImpl(IRVarLayout* layout) SLANG_OVERRIDE;
+
+ virtual bool tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) SLANG_OVERRIDE;
+ virtual void emitSimpleValueImpl(IRInst* inst) SLANG_OVERRIDE;
+ virtual void emitLoopControlDecorationImpl(IRLoopControlDecoration* decl) SLANG_OVERRIDE;
+ virtual void emitFuncDecorationImpl(IRDecoration* decoration) SLANG_OVERRIDE;
+ virtual void emitFuncDecorationsImpl(IRFunc* func) SLANG_OVERRIDE;
+
+ virtual void emitSwitchDecorationsImpl(IRSwitch* switchInst) SLANG_OVERRIDE;
+ virtual void emitIfDecorationsImpl(IRIfElse* ifInst) SLANG_OVERRIDE;
+
+ virtual void handleRequiredCapabilitiesImpl(IRInst* inst) SLANG_OVERRIDE;
+
+ virtual void emitGlobalInstImpl(IRInst* inst) SLANG_OVERRIDE;
+
+ // Emit a single `register` semantic, as appropriate for a given resource-type-specific layout info
+ // Keyword to use in the uniform case (`register` for globals, `packoffset` inside a `cbuffer`)
+ void _emitHLSLRegisterSemantic(LayoutResourceKind kind, EmitVarChain* chain, IRInst* inst, char const* uniformSemanticSpelling = "register");
+
+ // Emit all the `register` semantics that are appropriate for a particular variable layout
+ void _emitHLSLRegisterSemantics(EmitVarChain* chain, IRInst* inst, char const* uniformSemanticSpelling = "register");
+ void _emitHLSLRegisterSemantics(IRVarLayout* varLayout, IRInst* inst, char const* uniformSemanticSpelling = "register");
+
+ void _emitHLSLParameterGroupFieldLayoutSemantics(EmitVarChain* chain);
+ void _emitHLSLParameterGroupFieldLayoutSemantics(IRVarLayout* fieldLayout, EmitVarChain* inChain);
+
+ void _emitHLSLParameterGroup(IRGlobalParam* varDecl, IRUniformParameterGroupType* type);
+
+ void _emitHLSLTextureType(IRTextureTypeBase* texType);
+
+ void _emitHLSLSubpassInputType(IRSubpassInputType* subpassType);
+
+ void _emitHLSLDecorationSingleString(const char* name, IRFunc* entryPoint, IRStringLit* val);
+ void _emitHLSLDecorationSingleInt(const char* name, IRFunc* entryPoint, IRIntLit* val);
+
+ void _emitStageAccessSemantic(IRStageAccessDecoration* decoration, const char* name);
+};
+
+}
+#endif
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index fbc99b0ce..633d17345 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -94,6 +94,7 @@
#include "slang-emit-glsl.h"
#include "slang-emit-hlsl.h"
+#include "slang-emit-metal.h"
#include "slang-emit-cpp.h"
#include "slang-emit-cuda.h"
#include "slang-emit-torch.h"
@@ -1171,6 +1172,11 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outAr
sourceEmitter = new CUDASourceEmitter(desc);
break;
}
+ case SourceLanguage::Metal:
+ {
+ sourceEmitter = new MetalSourceEmitter(desc);
+ break;
+ }
default: break;
}
break;
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index defca19fd..86e255705 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -1467,6 +1467,7 @@ static bool doesTargetAllowUnresolvedFuncSymbol(TargetRequest* req)
switch (req->getTarget())
{
case CodeGenTarget::HLSL:
+ case CodeGenTarget::Metal:
case CodeGenTarget::DXIL:
case CodeGenTarget::DXILAssembly:
case CodeGenTarget::HostCPPSource:
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 380d1141e..eaf254190 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -16,6 +16,13 @@ bool isPointerOfType(IRInst* type, IROp opCode)
return false;
}
+IRType* getVectorElementType(IRType* type)
+{
+ if (auto vectorType = as<IRVectorType>(type))
+ return vectorType->getElementType();
+ return type;
+}
+
Dictionary<IRInst*, IRInst*> buildInterfaceRequirementDict(IRInterfaceType* interfaceType)
{
Dictionary<IRInst*, IRInst*> result;
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 40ba783b9..94ae4bc9f 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -75,6 +75,8 @@ Dictionary<IRInst*, IRInst*> buildInterfaceRequirementDict(IRInterfaceType* inte
bool isComInterfaceType(IRType* type);
+// If `type` is a vector, returns its element type. Otherwise, return `type`.
+IRType* getVectorElementType(IRType* type);
IROp getTypeStyle(IROp op);
IROp getTypeStyle(BaseType op);
diff --git a/source/slang/slang-profile.h b/source/slang/slang-profile.h
index a1c08fe6a..bd6feab23 100644
--- a/source/slang/slang-profile.h
+++ b/source/slang/slang-profile.h
@@ -18,6 +18,7 @@ namespace Slang
CPP = SLANG_SOURCE_LANGUAGE_CPP,
CUDA = SLANG_SOURCE_LANGUAGE_CUDA,
SPIRV = SLANG_SOURCE_LANGUAGE_SPIRV,
+ Metal = SLANG_SOURCE_LANGUAGE_METAL,
CountOf = SLANG_SOURCE_LANGUAGE_COUNT_OF,
};
diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp
index 30c41d0fe..e79c51256 100644
--- a/source/slang/slang-type-layout.cpp
+++ b/source/slang/slang-type-layout.cpp
@@ -1537,6 +1537,7 @@ LayoutRulesFamilyImpl* getDefaultLayoutRulesFamilyForTarget(TargetRequest* targe
case CodeGenTarget::ShaderSharedLibrary:
case CodeGenTarget::CPPSource:
case CodeGenTarget::CSource:
+ case CodeGenTarget::Metal:
{
// For now lets use some fairly simple CPU binding rules
@@ -1788,6 +1789,10 @@ SourceLanguage getIntermediateSourceLanguageForTarget(TargetProgram* targetProgr
// Currently DXBytecode and DXIL are generated via HLSL
return SourceLanguage::HLSL;
}
+ case CodeGenTarget::Metal:
+ {
+ return SourceLanguage::Metal;
+ }
case CodeGenTarget::CSource:
{
return SourceLanguage::C;
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index 9d018d47f..437f0cd63 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -1709,6 +1709,10 @@ CapabilitySet TargetRequest::getTargetCaps()
atoms.add(CapabilityName::cuda);
break;
+ case CodeGenTarget::Metal:
+ atoms.add(CapabilityName::metal);
+ break;
+
default:
break;
}
diff --git a/tests/metal/simple-compute.slang b/tests/metal/simple-compute.slang
new file mode 100644
index 000000000..e099704be
--- /dev/null
+++ b/tests/metal/simple-compute.slang
@@ -0,0 +1,10 @@
+//TEST:SIMPLE(filecheck=CHECK): -target metal
+
+RWStructuredBuffer<float> outputBuffer;
+
+// CHECK: {{.*}}kernel{{.*}} void main()
+[numthreads(1,1,1)]
+void main()
+{
+ outputBuffer[0] = 1.0f;
+} \ No newline at end of file
diff --git a/tools/slang-test/slang-test-main.cpp b/tools/slang-test/slang-test-main.cpp
index c93bebf33..cb77c8830 100644
--- a/tools/slang-test/slang-test-main.cpp
+++ b/tools/slang-test/slang-test-main.cpp
@@ -967,6 +967,7 @@ static PassThroughFlags _getPassThroughFlagsForTarget(SlangCompileTarget target)
case SLANG_CPP_PYTORCH_BINDING:
case SLANG_HOST_CPP_SOURCE:
case SLANG_CUDA_SOURCE:
+ case SLANG_METAL:
{
return 0;
}